In [None]:
def run_utils():
    """Verify that a GPU is visible and pick the torch device to run on.

    TensorFlow is asked for its GPU device name first; anything other than
    '/device:GPU:0' aborts with an error. PyTorch is then queried and the
    CUDA device is selected when available, the CPU otherwise.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()``, else ``cpu``.

    Raises:
        SystemError: if TensorFlow does not report '/device:GPU:0'.
    """
    tf_gpu_name = tf.test.gpu_device_name()
    if tf_gpu_name != '/device:GPU:0':
        raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(tf_gpu_name))

    # No CUDA device for PyTorch: fall back to the CPU.
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")

    chosen = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    return chosen


# Function to calculate the accuracy of our predictions vs labels
def predict(P_ideology, ideology_labels):
    """Return the fraction of rows whose arg-max class matches the target.

    Both arguments are reduced with ``torch.argmax`` along dimension 1, so
    each is expected to carry one score/indicator per class in that dimension.

    Returns:
        float: number of matching rows divided by the number of rows.
    """
    predicted_classes = torch.argmax(P_ideology, 1)
    target_classes = torch.argmax(ideology_labels, 1)

    hits = torch.eq(predicted_classes, target_classes)
    hit_count = int(hits.sum())

    return hit_count / len(predicted_classes)


# Function to calculate the accuracy of our predictions vs labels
def predict_binary(P_ideology, ideology_labels):
    """Return the accuracy of rounded probabilities against binary labels.

    The probabilities are rounded to 0/1 with ``torch.round`` and compared
    element-wise with ``ideology_labels``.

    The rounding and the count stay inside torch (no NumPy round-trip), so
    the function also works for tensors that require grad or live on a GPU.

    Args:
        P_ideology: torch tensor of probabilities.
        ideology_labels: torch tensor of 0/1 labels, comparable element-wise
            with ``P_ideology``.

    Returns:
        float: number of matching elements divided by ``len(P_ideology)``.
    """
    predicted = torch.round(P_ideology).int()

    # .item() yields a plain Python number regardless of the tensor's device.
    correct = torch.eq(predicted, ideology_labels).sum().item()

    return correct / len(predicted)


def find_maxLen_doc(data, tokenizer):
    """Print and return the longest encoded length among the given texts.

    Each text is passed to ``tokenizer.encode(..., add_special_tokens=True)``
    and the length of the returned id list is measured.

    Args:
        data: iterable of texts.
        tokenizer: object exposing ``encode(text, add_special_tokens=...)``.

    Returns:
        int: the maximum encoded length, or 0 when ``data`` is empty.
    """
    max_len = max(
        (len(tokenizer.encode(sent, add_special_tokens=True)) for sent in data),
        default=0,
    )

    print('Max sentence length: ', max_len)

    # Returned as well as printed so callers can use the value directly.
    return max_len

def format_time(elapsed):
    '''
    Render a duration given in seconds as an hh:mm:ss string.

    The value is rounded to whole seconds first and then formatted through
    ``datetime.timedelta``.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def load_tokenizer(model):
    """Load a lower-casing BERT tokenizer for the given pretrained model.

    Args:
        model: model name or path forwarded to
            ``BertTokenizer.from_pretrained``.

    Returns:
        The tokenizer created with ``do_lower_case=True``.
    """
    # Import only what is used: the other (unused) names previously imported
    # here include deprecated classes whose absence would break this function.
    from transformers import BertTokenizer

    return BertTokenizer.from_pretrained(model, do_lower_case=True)

def load_dataset(path):
    """Read an ideology TSV file into a DataFrame and print class counts.

    The file's own header row is replaced by a fixed set of eight column
    names. The ``docCont``, ``Q`` and ``title`` columns are lower-cased.
    Counts for ideology 0/1 ("Con"/"Lib") and stance 1/0 ("Pro"/"Against")
    are printed.

    Args:
        path: location of the tab-separated file.

    Returns:
        pandas.DataFrame: the loaded, lower-cased frame.
    """
    column_names = ['qID', 'q_ideology', 'ideology', 'stance',
                    'docCont', 'topic', 'Q', 'title']
    frame = pd.read_csv(path, delimiter='\t', header=0, names=column_names)

    for text_column in ('docCont', 'Q', 'title'):
        frame[text_column] = frame[text_column].str.lower()

    print("Train")
    summary = (
        ("Con", 'ideology', 0),
        ("Lib", 'ideology', 1),
        ("Pro", 'stance', 1),
        ("Against", 'stance', 0),
    )
    for caption, column, value in summary:
        print(caption, len(frame[frame[column] == value]))

    return frame

def load_dataset_ambigious(path):
    """Read an ambiguity TSV file into an all-string DataFrame.

    The file's header row is replaced by a fixed set of eight column names,
    ``docCont``/``Q``/``title`` are lower-cased, and every column is then
    converted to ``str``. Counts of rows flagged "1" and "0" in the
    ``Ambigious`` column and the resulting dtypes are printed.

    Args:
        path: location of the tab-separated file.

    Returns:
        pandas.DataFrame: the loaded frame with every column cast to str.
    """
    column_names = ['qID', 'docID', 'stance', 'Ambigious',
                    'ideology', 'docCont', 'Q', 'title']
    frame = pd.read_csv(path, delimiter='\t', header=0, names=column_names)

    for text_column in ('docCont', 'Q', 'title'):
        frame[text_column] = frame[text_column].str.lower()

    # After this cast the labels are the strings "0"/"1", not integers.
    frame = frame.astype(str)

    flags = frame['Ambigious']
    print("Train")
    print("Ambigious", len(frame[flags == "1"]))
    print("Non-ambigious", len(frame[flags == "0"]))

    print(frame.dtypes)

    return frame

def load_dataset_first(path):
    """Load a document TSV plus its companion query-id TSV and merge them.

    The companion file name is derived by replacing 'final.tsv' with
    'final_onlyqID.tsv' in ``path``. Its ``qID`` column replaces the one in
    the document file and its ``orientation`` column (-1 -> "Lib",
    1 -> "Con") is inserted as ``q_ideology``. Stance values "-1"/"Agst"
    become 0 and "1"/"Pro" become 1.

    Args:
        path: location of the document TSV file.

    Returns:
        pandas.DataFrame: documents with ``qID`` and ``q_ideology`` as the
        first two columns.
    """
    doc_columns = ['qID', 'ideology', 'stance', 'docCont', 'topic', 'Q', 'title']
    documents = pd.read_csv(path, delimiter='\t', header=0, names=doc_columns)

    query_path = path.replace('final.tsv', 'final_onlyqID.tsv')
    queries = pd.read_csv(query_path, delimiter='\t', header=0,
                          names=['qID', 'orientation'])
    orientation_names = {-1: "Lib", 1: "Con"}
    queries["orientation"] = queries["orientation"].replace(orientation_names)

    # Swap in the query ids from the companion file (aligned on the index).
    documents = documents.drop('qID', axis=1)
    documents.insert(0, "qID", queries['qID'], True)
    documents.insert(1, "q_ideology", queries['orientation'], True)

    stance_codes = {"-1": 0, "1": 1, "Pro": 1, "Agst": 0}
    documents["stance"] = documents["stance"].replace(stance_codes)

    return documents

def sample_dataset_stance(df, seedVal):
    """Split the ambiguity dataset into train/val/test and write them as TSV.

    Rows whose ``Ambigious`` value is the string "1" or "0" are kept (other
    values are dropped), split 80/20 into train+val/test and then 75/25 into
    train/val. The three parts are written to fixed paths under
    ``./dataset/batches_cleaned/stance/`` and returned.

    Args:
        df: frame with an ``Ambigious`` column holding the strings "0"/"1".
        seedVal: currently unused; the splits are not seeded here.

    Returns:
        tuple: ``(X_train, X_val, X_test)`` DataFrames, each with the
        ``Ambigious`` column re-inserted at position 2.
    """
    ambiguous_rows = df[df['Ambigious'] == "1"]
    clear_rows = df[df['Ambigious'] == "0"]

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    df_new = pd.concat([ambiguous_rows, clear_rows], ignore_index=True)

    y = df_new['Ambigious'].copy(deep=True)
    X = df_new.drop('Ambigious', axis=1).copy(deep=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

    print(len(X_train))
    print(len(y_train))
    print(len(X_test))
    print(len(y_test))
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, shuffle=True)

    # Put the label back next to the features before exporting.
    X_train.insert(2, "Ambigious", y_train.values)
    X_val.insert(2, "Ambigious", y_val.values)
    X_test.insert(2, "Ambigious", y_test.values)

    print("****Train****")
    print("Ambigious", X_train[X_train['Ambigious'] == "1"].shape[0])
    print("Not Ambigious", X_train[X_train['Ambigious'] == "0"].shape[0])

    print("****Test****")
    print("Ambigious", X_test[X_test['Ambigious'] == "1"].shape[0])
    print("Not Ambigious", X_test[X_test['Ambigious'] == "0"].shape[0])

    X_train.to_csv('./dataset/batches_cleaned/stance/train_serp.tsv', sep='\t', index=False)
    X_val.to_csv('./dataset/batches_cleaned/stance/val_serp.tsv', sep='\t', index=False)
    X_test.to_csv('./dataset/batches_cleaned/stance/test_serp.tsv', sep='\t', index=False)

    return X_train, X_val, X_test

def merge_datasets(df, dfVal, dfTest):
    """Concatenate three frames, drop incomplete rows and split off labels.

    Cells that are exactly "" or " " are treated as missing; any row with a
    missing value in any column is dropped.

    Args:
        df, dfVal, dfTest: DataFrames that each contain an ``ideology``
            column. They are not modified.

    Returns:
        tuple: ``(features, labels)`` where ``features`` is the merged frame
        without ``ideology`` and ``labels`` is the ``ideology`` Series.
    """
    from numpy import nan

    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    merged = pd.concat([df, dfVal, dfTest], ignore_index=True)

    merged = merged.replace(["", " "], nan)
    # Only `how` is passed: current pandas rejects `how` and `thresh` together.
    merged = merged.dropna(axis=0, how='any')

    dfLabel = merged['ideology'].copy(deep=True)
    features = merged.drop('ideology', axis=1).copy(deep=True)

    return features, dfLabel

def preprocess_dataset_new_ideology_latest(df_new, testPer, seedVal):
    """Make a shuffled split of ``df_new`` stratified on ``ideology``.

    ``create_determinism(seedVal)`` is called before splitting.

    Args:
        df_new: DataFrame containing an ``ideology`` column.
        testPer: value forwarded as ``test_size`` to ``train_test_split``.
        seedVal: value forwarded to ``create_determinism``.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)`` — note the order
        differs from the one ``train_test_split`` itself returns.
    """
    from sklearn.model_selection import train_test_split
    from pandas import DataFrame

    create_determinism(seedVal)

    labels = df_new['ideology'].copy(deep=True)
    features = df_new.drop('ideology', axis=1).copy(deep=True)

    split = train_test_split(
        features,
        labels,
        test_size=testPer,
        shuffle=True,
        stratify=labels,
    )
    train_features, test_features, train_labels, test_labels = split

    return train_features, train_labels, test_features, test_labels

def preprocess_dataset_new(df, dfLabels, testPer, seedVal):
    """Split the data into train/test separately for every query ``Q``.

    Rows are grouped by their ``Q`` value (groups visited in sorted order).
    A group with a single row goes entirely to the training set. Larger
    groups are split with ``train_test_split``; the split is stratified on
    ``ideology`` only when both classes (0 and 1) have at least two rows and
    ``(con + lib) * testPer > 1``.

    Every group, including the last one, takes part in the split. Column
    names come from the frame itself, so ``stance`` and ``ideology`` keep
    their own values. The caller's ``df`` is left unmodified.

    Args:
        df: feature DataFrame with a ``Q`` column and no ``ideology`` column.
        dfLabels: Series of ideology labels aligned positionally with ``df``.
        testPer: value forwarded as ``test_size`` to ``train_test_split``.
        seedVal: value forwarded to ``create_determinism``.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)``; the ``y`` parts are
        single-column DataFrames named ``ideology``. All four have a fresh
        0..n-1 index.
    """
    from sklearn.model_selection import train_test_split

    create_determinism(seedVal)

    # Work on a copy so the caller's frame does not silently gain a column.
    df = df.copy()
    df.insert(2, "ideology", dfLabels.values)

    feature_columns = [col for col in df.columns if col != 'ideology']

    X_train_parts, y_train_parts = [], []
    X_test_parts, y_test_parts = [], []

    for _, group in df.groupby('Q', sort=True):
        y = group['ideology']
        X = group.drop('ideology', axis=1)

        # A single document cannot be split; keep it for training.
        if len(group) < 2:
            X_train_parts.append(X)
            y_train_parts.append(y)
            continue

        con_count = int((y == 0).sum())
        lib_count = int((y == 1).sum())
        can_stratify = (
            con_count >= 2
            and lib_count >= 2
            and (con_count + lib_count) * testPer > 1
        )

        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y,
            test_size=testPer,
            shuffle=True,
            stratify=y if can_stratify else None,
        )
        X_train_parts.append(X_tr)
        y_train_parts.append(y_tr)
        X_test_parts.append(X_te)
        y_test_parts.append(y_te)

    def _stack_features(parts):
        # Concatenate per-query feature frames (empty frame when no parts).
        if not parts:
            return pd.DataFrame(columns=feature_columns)
        return pd.concat(parts, ignore_index=True)

    def _stack_labels(parts):
        # Concatenate per-query label series into a one-column frame.
        if not parts:
            return pd.DataFrame(columns=['ideology'])
        return pd.concat(parts, ignore_index=True).to_frame(name='ideology')

    X_train = _stack_features(X_train_parts)
    y_train = _stack_labels(y_train_parts)
    X_test = _stack_features(X_test_parts)
    y_test = _stack_labels(y_test_parts)

    return X_train, y_train, X_test, y_test

def create_train_val_test_split(trainpath, valpath, testpath, testPer, valPer, seedVal):
    """Re-split the ideology dataset and write train/val/test TSV files.

    The three sample files under ``./dataset/ideology/`` are loaded, merged
    and cleaned, then split twice with a 0.2 hold-out (first for test, then
    for validation). The results are written to ``train_new.tsv``,
    ``val_new.tsv`` and ``test_new.tsv`` in the working directory.

    NOTE: ``trainpath``, ``valpath``, ``testpath``, ``testPer`` and
    ``valPer`` are accepted but not used; input paths and split ratios are
    hard-coded. Only ``seedVal`` is forwarded to the split helper.
    """
    source_files = (
        "./dataset/ideology/train_samples.tsv",
        "./dataset/ideology/val_samples.tsv",
        "./dataset/ideology/test_samples.tsv",
    )
    train_raw, val_raw, test_raw = [load_dataset(src) for src in source_files]

    merged, merged_labels = merge_datasets(train_raw, val_raw, test_raw)
    merged.insert(2, "ideology", merged_labels.values)

    # First split: hold out the test set.
    train_val, train_val_labels, test_part, test_labels = \
        preprocess_dataset_new_ideology_latest(merged, 0.2, seedVal)
    train_val.insert(2, "ideology", train_val_labels.values)

    # Second split: hold out the validation set from the remainder.
    train_part, train_labels, val_part, val_labels = \
        preprocess_dataset_new_ideology_latest(train_val, 0.2, seedVal)

    outputs = (
        (train_part, train_labels, 'train_new.tsv'),
        (val_part, val_labels, 'val_new.tsv'),
        (test_part, test_labels, 'test_new.tsv'),
    )
    for frame, labels, _ in outputs:
        frame.insert(2, "ideology", labels.values)
    for frame, _, out_name in outputs:
        frame.to_csv(out_name, sep='\t', index=False)

def preprocess_ideologyOld(stance_labels, ideology_labels):
    """Encode stance/ideology label pairs as one-hot rows plus two flags.

    For every position the pair (stance, ideology) is mapped as follows:
      (1, 0) pro-con   -> stance [1,0], ideology [0,1], flags 1, 0
      (1, 1) pro-lib   -> stance [1,0], ideology [1,0], flags 1, 1
      (0, 0) agst-con  -> stance [0,1], ideology [0,1], flags 0, 0
      anything else    -> stance [0,1], ideology [1,0], flags 0, 1 (agst-lib)

    Returns:
        tuple: ``(t_stance, t_ideology, t_mmd_symbol, t_mmd_symbol_)``; the
        first two are int32 tensors, the last two float32 tensors.
    """
    agst_lib = ([0, 1], [1, 0], 0, 1)

    stance_rows, ideology_rows = [], []
    symbol_a, symbol_b = [], []

    for idx, s_label in enumerate(stance_labels):
        i_label = ideology_labels[idx]

        if s_label == 1:
            if i_label == 0:
                encoded = ([1, 0], [0, 1], 1, 0)   # pro-con
            elif i_label == 1:
                encoded = ([1, 0], [1, 0], 1, 1)   # pro-lib
            else:
                encoded = agst_lib
        elif s_label == 0 and i_label == 0:
            encoded = ([0, 1], [0, 1], 0, 0)       # agst-con
        else:
            encoded = agst_lib

        stance_row, ideology_row, flag_a, flag_b = encoded
        stance_rows.append(list(stance_row))
        ideology_rows.append(list(ideology_row))
        symbol_a.append(flag_a)
        symbol_b.append(flag_b)

    return (
        torch.as_tensor(stance_rows, dtype=torch.int32),
        torch.as_tensor(ideology_rows, dtype=torch.int32),
        torch.as_tensor(symbol_a, dtype=torch.float32),
        torch.as_tensor(symbol_b, dtype=torch.float32),
    )

def preprocess_ideology_ambigious(ambigious_labels):
    """Turn string ambiguity labels into an int32 tensor of one-element rows.

    The string "0" becomes ``[0]``; every other value becomes ``[1]``.

    Args:
        ambigious_labels: indexable sequence of labels.

    Returns:
        torch.Tensor: int32 tensor with one row per label.
    """
    rows = [
        [0] if ambigious_labels[idx] == "0" else [1]
        for idx, _ in enumerate(ambigious_labels)
    ]
    return torch.as_tensor(rows, dtype=torch.int32)

def preprocess_ideology_new(stance_labels, ideology_labels):
    """Turn ideology labels into an int32 tensor of one-element rows.

    An ideology label equal to 0 (con) becomes ``[0]``; anything else (lib)
    becomes ``[1]``. ``stance_labels`` only determines how many positions
    are read; its values are not used.

    Returns:
        torch.Tensor: int32 tensor with one row per position.
    """
    rows = [
        [0] if ideology_labels[idx] == 0 else [1]
        for idx, _ in enumerate(stance_labels)
    ]
    return torch.as_tensor(rows, dtype=torch.int32)

def preprocess_ideology(stance_labels, ideology_labels):
    """Encode each (stance, ideology) pair as a two-element int32 row.

    Mapping per position:
      (1, 0) pro-con  -> [1, 0]
      (1, 1) pro-lib  -> [1, 1]
      (0, 0) agst-con -> [0, 0]
      anything else   -> [0, 1] (agst-lib)

    Returns:
        torch.Tensor: a single int32 tensor with one row per position.
    """
    rows = []
    for idx, s_label in enumerate(stance_labels):
        i_label = ideology_labels[idx]

        if s_label == 1:
            if i_label == 0:
                row = [1, 0]    # pro-con
            elif i_label == 1:
                row = [1, 1]    # pro-lib
            else:
                row = [0, 1]    # falls through to agst-lib
        elif s_label == 0 and i_label == 0:
            row = [0, 0]        # agst-con
        else:
            row = [0, 1]        # agst-lib
        rows.append(row)

    return torch.as_tensor(rows, dtype=torch.int32)

def concanListStringsLonger(list1, list2):
    """Join the two lists element-wise with the marker " GIZEM " in between.

    A length mismatch is only reported with a printed message; the loop
    still runs over ``len(list1)`` positions.

    Returns:
        list: ``list1[i] + " GIZEM " + list2[i]`` for every index of list1.
    """
    pair_count = len(list1)
    if len(list2) != pair_count:
        print("Length - error")
    return [list1[pos] + " GIZEM " + list2[pos] for pos in range(pair_count)]

def concanListStrings(list1, list2):
    """Join the two lists element-wise with a single space in between.

    Elements of ``list2`` are converted with ``str()`` first, so numeric
    values (for example stance labels) can be appended to text. String
    elements are unaffected by the conversion.

    A length mismatch is only reported with a printed message; the loop
    still runs over ``len(list1)`` positions.

    Returns:
        list: ``list1[i] + " " + str(list2[i])`` for every index of list1.
    """
    pair_count = len(list1)
    if len(list2) != pair_count:
        print("Length - error")
    return [list1[idx] + " " + str(list2[idx]) for idx in range(pair_count)]

def concanListStrings_sep(list1, list2):
    """Join the two lists element-wise with " [SEP] " in between.

    Elements of ``list2`` are converted with ``str()``. A length mismatch is
    only reported with a printed message; the loop still runs over
    ``len(list1)`` positions.

    Returns:
        list: ``list1[i] + " [SEP] " + str(list2[i])`` for every index.
    """
    pair_count = len(list1)
    if len(list2) != pair_count:
        print("Length - error")

    return [list1[pos] + " [SEP] " + str(list2[pos]) for pos in range(pair_count)]

### Generate the datasets with the different fields.
def generate_datasets_ambigious(df, tokenizer):
    """Build the text inputs and labels for the ambiguity task.

    Args:
        df: DataFrame with ``Q``, ``title``, ``docCont`` and ``Ambigious``
            columns.
        tokenizer: not used by this function.

    Returns:
        tuple: ``(query+title texts, query+title+content texts, labels)``;
        the content is attached via ``concanListStringsLonger``.
    """
    queries = df.Q.values
    titles = df.title.values
    contents = df.docCont.values
    labels = df.Ambigious.values

    query_title = concanListStrings(queries, titles)
    query_title_content = concanListStringsLonger(query_title, contents)

    return query_title, query_title_content, labels


### Generate the datasets with the different fields.
def generate_datasets_ideology(df, tokenizer):
    """Build the text input variants and labels for the ideology task.

    Args:
        df: DataFrame with ``Q``, ``q_ideology``, ``title``, ``docCont``,
            ``stance`` and ``ideology`` columns.
        tokenizer: not used by this function.

    Returns:
        tuple: ``(query+title, query+title+content, query+title+stance,
        query+title+stance+content, stances, labels)``; content is attached
        via ``concanListStringsLonger``, the rest via ``concanListStrings``.
    """
    queries = df.Q.values
    query_ideologies = df.q_ideology.values  # read but not used further
    titles = df.title.values
    contents = df.docCont.values

    stances = df.stance.values
    labels = df.ideology.values

    query_title = concanListStrings(queries, titles)
    query_title_stance = concanListStrings(query_title, stances)
    query_title_content = concanListStringsLonger(query_title, contents)
    query_title_stance_content = concanListStringsLonger(query_title_stance, contents)

    return (
        query_title,
        query_title_content,
        query_title_stance,
        query_title_stance_content,
        stances,
        labels,
    )

def preprocessing_for_bert(tokenizer, docs, max_len, doc_stride):
    """Encode every text with ``tokenizer.encode_plus`` and stack the results.

    Each text is encoded with special tokens, ``max_length=max_len``,
    ``pad_to_max_length=True`` and PyTorch tensor output; the per-text
    ``input_ids`` and ``attention_mask`` tensors are concatenated along
    dimension 0.

    @param    tokenizer: object exposing ``encode_plus``.
    @param    docs: iterable of texts to be processed.
    @param    max_len: maximum length forwarded to the tokenizer.
    @param    doc_stride: not used by this function.
    @return   input_ids (torch.Tensor): concatenated token id tensors.
    @return   attention_masks (torch.Tensor): concatenated attention masks.
    """
    id_chunks = []
    mask_chunks = []

    for text in docs:
        encoding = tokenizer.encode_plus(
            text,
            add_special_tokens=True,      # add `[CLS]` and `[SEP]`
            max_length=max_len,           # length to truncate/pad to
            pad_to_max_length=True,
            return_tensors='pt',          # PyTorch tensors
            return_attention_mask=True,
        )
        id_chunks.append(encoding['input_ids'])
        mask_chunks.append(encoding['attention_mask'])

    # One row per text after concatenation.
    stacked_ids = torch.cat(id_chunks, dim=0)
    stacked_masks = torch.cat(mask_chunks, dim=0)

    return stacked_ids, stacked_masks

def transform_sequences_longer_ideology(tokenizer, docs, stanceLabels, ideologylabels, max_len, doc_stride):
    """Encode documents for the model, splitting long ones into windows.

    Each document is expected to look like "<first part> GIZEM <rest>" (the
    format produced by ``concanListStringsLonger``). The first part is
    repeated at the start of every window; the rest is cut into windows
    whose start advances by ``doc_stride - 3`` tokens. Every window is
    encoded with ``tokenizer.encode_plus`` and gets a copy of the
    document's stance and ideology labels.

    Args:
        tokenizer: object exposing ``tokenize`` and ``encode_plus``.
        docs: indexable sequence of document strings.
        stanceLabels: indexable sequence of stance labels, one per document.
        ideologylabels: indexable sequence of ideology labels, one per document.
        max_len: ``max_length`` passed to ``encode_plus``; also bounds the
            window size.
        doc_stride: window step control; 0 means "keep only the first
            window" of every document.

    Returns:
        tuple: ``(all_input_ids, all_input_mask, stance_labels,
        ideology_labels)`` — the concatenated encodings and the labels
        repeated once per produced window, as torch tensors.

    NOTE(review): with ``doc_stride == 3`` the step below is 0, so the
    ``while`` loop never shrinks ``myTokens`` — looks like it would not
    terminate for long documents; confirm callers never pass 3.
    """

    special_tokens_count = 2 #[CLS] and [SEP]
    # Accumulators: one entry per produced window.
    input_ids = []
    attention_masks = []
    stance_labels_Transformed = []
    ideology_labels_Transformed = []

    # doc_stride == 0 switches off windowing: only the first window is kept.
    only_get_partial_text = False
    if doc_stride == 0:
        only_get_partial_text = True
        
    # Number of tokens removed from the front of the text after each window.
    checked_doc_stride_thresh = doc_stride - special_tokens_count - 1
        
    allDocs_len = len(docs)
    for doc_id in range(0, allDocs_len):
        currDoc = docs[doc_id]
        currStanceLabel = stanceLabels[doc_id]
        currIdeologyLabel = ideologylabels[doc_id]
        
        # Locate the "GIZEM" marker word that separates the two parts.
        # NOTE(review): the `in` test is a substring check while `.index`
        # needs an exact space-delimited token; a word merely containing
        # "GIZEM" would raise ValueError here.
        my_idx = 0
        if "GIZEM" in currDoc:
            doc_splitted_tokens = currDoc.split(" ")
            my_idx = doc_splitted_tokens.index('GIZEM')
        else:
            doc_splitted_tokens = currDoc.split(" ")
        
        # First part: everything before the marker; rest: everything after it.
        # NOTE(review): without a marker my_idx stays 0, so the first part is
        # empty and the `my_idx+1` slice skips the document's first word.
        first_part_tokens = tokenizer.tokenize(' '.join(doc_splitted_tokens[0:my_idx]))
        myTokens = tokenizer.tokenize(' '.join(doc_splitted_tokens[my_idx+1:]))
        mytokens_maxlen = []

        first_part_len = len(first_part_tokens)
        cur_len = len(myTokens)
        # Number of "rest" tokens that fit in one window next to the first part.
        taken_len = max_len - first_part_len - special_tokens_count - 1
        
        if only_get_partial_text:
            mytokens_maxlen.append(first_part_tokens + myTokens[0:taken_len])
        else:
            checked_thresh = max_len - first_part_len - special_tokens_count
            if cur_len > checked_thresh:
                # Too long for one window: emit windows until the remainder fits.
                while cur_len > checked_thresh:
                    partialTokens = first_part_tokens + myTokens[0:taken_len]
                    mytokens_maxlen.append(partialTokens)
                    del myTokens[0:checked_doc_stride_thresh]
                    cur_len = len(myTokens)
                # Whatever is left becomes the final window.
                if cur_len > 0:
                    mytokens_maxlen.append(first_part_tokens + myTokens)
            else:
                mytokens_maxlen.append(first_part_tokens + myTokens)

        if len(mytokens_maxlen) == 1:
            # Single window: the raw document string is encoded, not the
            # token list built above.
            # NOTE(review): currDoc still contains the "GIZEM" marker word
            # here — confirm that is intended.
            encoded_dict = tokenizer.encode_plus(
                        currDoc,                      # Sentence to encode.
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        max_length = max_len,           # Pad & truncate all sentences.
                        pad_to_max_length = True,
                        return_attention_mask = True,   # Construct attn. masks.
                        return_tensors = 'pt',     # Return pytorch tensors.
                   )
    
            # Add the encoded sentence to the list.
            input_ids.append(encoded_dict['input_ids'])
    
            # And its attention mask (simply differentiates padding from non-padding).
            attention_masks.append(encoded_dict['attention_mask'])

            stance_labels_Transformed.append(currStanceLabel)
            ideology_labels_Transformed.append(currIdeologyLabel)
        else:
            for maxTokenList in mytokens_maxlen:
                # Diagnostic only: report windows longer than 510 tokens.
                if len(maxTokenList) > 510:
                    print(len(maxTokenList))
                # Multiple windows: each pre-tokenized window is encoded
                # (ids mapped, padded/truncated to max_length, mask built).
                encoded_dict = tokenizer.encode_plus(
                    maxTokenList,                      # Sentence to encode.
                    add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                    max_length = max_len,           # Pad & truncate all sentences.
                    pad_to_max_length = True,
                    return_attention_mask = True,   # Construct attn. masks.
                    return_tensors = 'pt',     # Return pytorch tensors.
                    )
 
                input_ids.append(encoded_dict['input_ids'])
                attention_masks.append(encoded_dict['attention_mask'])
                # Every window inherits the labels of its source document.
                stance_labels_Transformed.append(currStanceLabel)
                ideology_labels_Transformed.append(currIdeologyLabel)


    # Stack all windows into single tensors.
    all_input_ids = torch.cat(input_ids, dim=0)
    all_input_mask = torch.cat(attention_masks, dim=0)
    stance_labels = torch.tensor(stance_labels_Transformed)
    ideology_labels = torch.tensor(ideology_labels_Transformed)
    
    print(all_input_ids.shape)

    return all_input_ids, all_input_mask, stance_labels, ideology_labels

import torch
from transformers import BertModel, RobertaModel
class IdeologyDetectionClass(torch.nn.Module):
    """BERT encoder followed by a hidden layer and a 2-way softmax head.

    Pipeline in ``forward``: encoder output at the first token position ->
    Linear(768, 768) -> BatchNorm1d -> ReLU -> Dropout(0.5) ->
    Linear(768, 2) -> Softmax over dimension 1.
    """

    def __init__(self, modelUsed):
        """Build the layers.

        Args:
            modelUsed: model name or path forwarded to
                ``BertModel.from_pretrained``.
        """
        super(IdeologyDetectionClass, self).__init__()
        input_size = 768
        hidden_size = 768
        dropout_prob = 0.5
        classes_size = 2

        self.input_pl = BertModel.from_pretrained(modelUsed) #input
        self.l1 = torch.nn.Linear(input_size, hidden_size)
        self.bn1_hidden = torch.nn.BatchNorm1d(hidden_size, momentum=0.05)
        self.dropout = torch.nn.Dropout(dropout_prob)

        self.stance = torch.nn.Linear(hidden_size, classes_size)
        self.output_prob = torch.nn.Softmax(dim = 1)

    def forward(self, input_ids, attention_mask):
        """Return class probabilities for a batch.

        Args:
            input_ids: token id tensor passed to the encoder.
            attention_mask: attention mask tensor passed to the encoder.

        Returns:
            torch.Tensor: softmax output of the classification head, one
            row per input row and one column per class.
        """
        encoder_output = self.input_pl(input_ids = input_ids, attention_mask = attention_mask)
        # First element of the encoder output, taken at token position 0
        # (presumably the [CLS] vector of the last hidden state).
        last_hidden_state_cls = encoder_output[0][:, 0, :]

        # Hidden layer. The activation is applied functionally: the module
        # never defined a `relu` attribute, so `self.relu(...)` raised
        # AttributeError on every forward pass.
        hidden_state = self.l1(last_hidden_state_cls)
        hidden_state_normalized = self.bn1_hidden(hidden_state)
        hidden_state_normalized = torch.relu(hidden_state_normalized)
        hidden_layer = self.dropout(hidden_state_normalized)

        stance_state = self.stance(hidden_layer) #batch size x classes_size
        P_stance = self.output_prob(stance_state)

        return P_stance

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set to True once patience is exhausted
        # np.inf (lowercase) is used: the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.val_acc_max = -1
        self.delta = delta

    def __call__(self, val_loss, val_acc, model_save_state, model_save_path):
        """Record one validation result.

        On the first call, and whenever the negated loss reaches at least
        ``best_score + delta``, the state is saved via ``save_checkpoint``
        and the patience counter is reset to 0. Otherwise the counter is
        incremented and ``early_stop`` is set once it reaches ``patience``.

        Args:
            val_loss: validation loss of the current epoch.
            val_acc: validation accuracy of the current epoch (only stored
                and reported, not used for the stopping decision).
            model_save_state: object passed to ``torch.save``.
            model_save_path: destination passed to ``torch.save``.
        """
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, val_acc, model_save_state, model_save_path)
            self.val_acc_max = val_acc
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: checkpoint and restart the patience window. The
            # counter must stay at 0 here; bumping it again would count the
            # improvement itself as a miss.
            self.best_score = score
            self.save_checkpoint(val_loss, val_acc, model_save_state, model_save_path)
            self.val_acc_max = val_acc
            self.counter = 0

    def save_checkpoint(self, val_loss, val_acc, model_save_state, model_save_path):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            print(f'Validation acc: ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')

        torch.save(model_save_state, model_save_path)

        self.val_loss_min = val_loss

#from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
def prepare_for_training_ambigious(input_idsTrain, attention_masksTrain, ideology_labels_Train, input_idsVal, attention_masksVal, ideology_labels_Val, modelUsed, batch_size=16, epochs = 50, num_warmup_steps=0, learning_rate=5e-5):
    """Build the model, datasets, optimizer and scheduler for ambiguity training.

    Labels are converted with ``preprocess_ideology_ambigious`` and wrapped
    together with the inputs in ``TensorDataset`` objects; an
    ``IdeologyDetectionClass`` model is created and moved to the GPU.

    Args:
        input_idsTrain, attention_masksTrain: training input tensors.
        ideology_labels_Train: training labels ("0"/"1" strings).
        input_idsVal, attention_masksVal: validation input tensors.
        ideology_labels_Val: validation labels ("0"/"1" strings).
        modelUsed: pretrained model name/path for the encoder.
        batch_size: batch size of the internal training DataLoader.
        epochs: number of epochs, used to compute the total step count.
        num_warmup_steps: warm-up steps for the schedulers.
        learning_rate: learning rate for AdamW.

    Returns:
        tuple: ``(model, datasetTrain, datasetVal, optimizer, schedulerOld)``
        where ``schedulerOld`` is the linear warm-up schedule.

    NOTE: the training DataLoader, the cosine-with-restarts scheduler and
    the BCELoss created below are not returned. ``TensorDataset`` is not
    imported inside this function and must be available at module level.
    """
    # Combine the training inputs into a TensorDataset.

    from transformers import BertForSequenceClassification, AdamW, BertConfig, RobertaConfig, AutoModelWithLMHead
    from transformers import DistilBertForSequenceClassification, RobertaForSequenceClassification
    
    from torch.utils.data import DataLoader, RandomSampler
    
    t_train_stance = preprocess_ideology_ambigious(ideology_labels_Train)
    
    datasetTrain = TensorDataset(input_idsTrain, attention_masksTrain, t_train_stance)

    # Combine the validation inputs into a TensorDataset.
    t_val_stance  = preprocess_ideology_ambigious(ideology_labels_Val)
    
    
    datasetVal = TensorDataset(input_idsVal, attention_masksVal, t_val_stance)
    
    model = IdeologyDetectionClass(modelUsed)

    # Tell pytorch to run this model on the GPU (unconditional: requires CUDA).
    model.cuda()

    # Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
    # I believe the 'W' stands for 'Weight Decay fix"
    
    
    optimizer = AdamW(model.parameters(),
                  lr = learning_rate, # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  betas=(0.9, 0.999), 
                  eps=1e-08, 
                  weight_decay=1e-3,
                  correct_bias=True
               )

    # This loader is only used below to count the batches per epoch.
    train_dataloader = DataLoader(
            datasetTrain,  # The training samples.
            sampler =  RandomSampler(datasetTrain), # Select batches randomly
            batch_size = batch_size, # Trains with this batch size., 
            num_workers=8
        )
    batch_size = batch_size


    from transformers import get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup

    # Number of training epochs comes from the `epochs` argument.
    epochs = epochs

    # Total number of training steps is [number of batches] x [number of epochs]. 
    # (Note that this is not the same as the number of training samples).
    total_steps = len(train_dataloader) * epochs

    # Create the learning rate scheduler (this is the one that is returned).
    schedulerOld = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = num_warmup_steps, # Default value in run_glue.py
                                            num_training_steps = total_steps)
    
    # NOTE(review): a second scheduler is attached to the same optimizer but
    # never returned — presumably left over from an experiment; confirm it
    # does not interfere with the returned linear schedule.
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps = num_warmup_steps, # Default value in run_glue.py
                                            num_training_steps = total_steps, num_cycles = 5)
    
    # Created but not returned or used in this function.
    loss_fct = torch.nn.BCELoss()
    return model, datasetTrain, datasetVal, optimizer, schedulerOld

#from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
def prepare_for_trainingOld(input_idsTrain, attention_masksTrain, stance_labels_Train, ideology_labels_Train, input_idsVal, attention_masksVal, stance_labels_Val, ideology_labels_Val, modelUsed, batch_size=16, epochs = 50, num_warmup_steps=0, learning_rate=5e-5):
    """Build the model, datasets, optimizer, LR scheduler and loss for training.

    The stance/ideology labels of each split are run through
    `preprocess_ideology` and bundled with the token ids and attention masks
    into a six-tensor `TensorDataset`. An `IdeologyDetectionClass` model is
    created from `modelUsed` and moved to the GPU, then paired with a
    Hugging Face `AdamW` optimizer and a linear warmup/decay schedule sized
    for `epochs` passes over the training set.

    Args:
        input_idsTrain, attention_masksTrain: Training token ids and masks.
        stance_labels_Train, ideology_labels_Train: Training labels, passed to
            `preprocess_ideology`.
        input_idsVal, attention_masksVal: Validation token ids and masks.
        stance_labels_Val, ideology_labels_Val: Validation labels, passed to
            `preprocess_ideology`.
        modelUsed: Forwarded unchanged to `IdeologyDetectionClass`.
        batch_size: Batch size used to count optimizer steps per epoch.
        epochs: Number of epochs the schedule should span.
        num_warmup_steps: Warmup steps of the linear schedule.
        learning_rate: Peak learning rate given to the optimizer.

    Returns:
        Tuple `(model, datasetTrain, datasetVal, optimizer, scheduler,
        loss_fct)`; `scheduler` is the linear-warmup schedule and `loss_fct`
        is a `torch.nn.BCELoss` instance.
    """
    from torch.utils.data import DataLoader, RandomSampler
    # AdamW is taken from transformers (not torch.optim) because the call
    # below relies on its `correct_bias` argument.
    from transformers import AdamW, get_linear_schedule_with_warmup

    # Training split: ids, masks and the four tensors derived from the labels.
    t_train_stance, t_train_ideology, t_train_mmd_symbol, t_train_mmd_symbol_ = preprocess_ideology(stance_labels_Train, ideology_labels_Train)
    datasetTrain = TensorDataset(input_idsTrain, attention_masksTrain, t_train_stance, t_train_ideology, t_train_mmd_symbol, t_train_mmd_symbol_)

    # Validation split, same layout as the training split.
    t_val_stance, t_val_ideology, t_val_mmd_symbol, t_val_mmd_symbol_ = preprocess_ideology(stance_labels_Val, ideology_labels_Val)
    datasetVal = TensorDataset(input_idsVal, attention_masksVal, t_val_stance, t_val_ideology, t_val_mmd_symbol, t_val_mmd_symbol_)

    model = IdeologyDetectionClass(modelUsed)

    # Run the model on the GPU (raises if CUDA is unavailable).
    model.cuda()

    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=1e-3,
        correct_bias=True,
    )

    # This loader exists only to count batches per epoch; callers build
    # their own loaders from the returned datasets.
    train_dataloader = DataLoader(
        datasetTrain,
        sampler=RandomSampler(datasetTrain),
        batch_size=batch_size,
        num_workers=8,
    )

    # Total optimizer steps = batches per epoch x epochs
    # (not the number of training samples).
    total_steps = len(train_dataloader) * epochs

    # NOTE: an earlier revision also built a cosine-with-hard-restarts
    # schedule here but never returned it; only the linear schedule was ever
    # handed back, so the unused one has been dropped.
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    loss_fct = torch.nn.BCELoss()
    return model, datasetTrain, datasetVal, optimizer, scheduler, loss_fct

def return_batches_datasets(datasetTrain, datasetVal, batch_size = 16):
    """Wrap the training and validation datasets in `DataLoader`s.

    Args:
        datasetTrain: Dataset to draw training batches from, in random order.
        datasetVal: Dataset to draw validation batches from, in stored order.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        Tuple `(train_dataloader, validation_dataloader)`. Both load data in
        the calling process (`num_workers=0`) and keep the last partial batch.
    """
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

    # Settings shared by both loaders.
    loader_options = {"batch_size": batch_size, "num_workers": 0}

    # Training samples are visited in a fresh random order on every pass.
    shuffled_loader = DataLoader(
        datasetTrain,
        sampler=RandomSampler(datasetTrain),
        **loader_options,
    )

    # Order is irrelevant for validation, so read it front to back.
    ordered_loader = DataLoader(
        datasetVal,
        sampler=SequentialSampler(datasetVal),
        **loader_options,
    )

    return shuffled_loader, ordered_loader

def optimizer_to(optim, device):
    """Move every tensor held in an optimizer's state to `device`, in place.

    Args:
        optim: Object exposing a `state` mapping. Values that are tensors, or
            dicts containing tensors, are relocated; anything else is left
            untouched.
        device: Target device accepted by `Tensor.to`.

    Returns:
        None. The state tensors are updated through their `.data` attribute.
    """
    def _relocate(tensor):
        # Rebind the storage (and the gradient's storage, if one exists) so
        # the tensor object itself keeps its identity.
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # Top-level tensors may never occur in practice; handled anyway.
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        # Usual case: one dict of buffers per parameter.
        for item in entry.values():
            if isinstance(item, torch.Tensor):
                _relocate(item)


from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, RandomSampler
#from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
#import EarlyStopping
def train_stance_ideology_ambigious(train_nums, val_nums, train_nums_ideology, val_nums_ideology, model_save_path, 
                                    model, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train `model` with gradient accumulation, validating after every epoch.

    Each epoch runs one pass over the training loader (the optimizer and
    scheduler step once every `accumulation_steps` micro-batches), then one
    pass over the validation loader. It prints the loss and per-class
    accuracies, appends a row to the statistics list and hands a checkpoint
    dict to `EarlyStopping`, which may end the run early.

    Args:
        train_nums: Not read by this function.
        val_nums: Indexable with four per-class validation counts, read as
            [0] pro/agree, [1] against/disagree, [2] neutral/discuss,
            [3] not-related/unrelated. Used as accuracy denominators and to
            derive `weights_stance`.
        train_nums_ideology: Not read (its only uses are commented out).
        val_nums_ideology: Not read (its only uses are commented out).
        model_save_path: Passed to `EarlyStopping`; also the checkpoint file
            that is loaded when `continue_train` is true.
        model: Called as `model(input_ids=..., attention_mask=..., mmd_pl=...,
            mmd_pl_=...)`; its return value is used directly as the stance
            output.
        datasetTrain, datasetVal: Datasets whose batches expose at least
            eight tensors (indices 0-7 are read below).
        epochs: Maximum number of epochs to run in this call.
        batch_size: Effective batch size. Micro-batches are capped at 16 and
            the rest is covered by gradient accumulation.
        optimizer, scheduler: Stepped together once per accumulation group.
        patience, verbose, delta: Forwarded positionally to `EarlyStopping`.
        seedVal: Not read by this function.
        continue_train: When true, restore model/optimizer state and the
            starting epoch number from the checkpoint at `model_save_path`.

    Returns:
        Tuple `(training_stats, last_epoch, min_val_loss, max_val_acc)`, where
        the last two are `es.val_loss_min` and `es.val_acc_max_stance`.
    """
    
    # Per-class validation counts for the four stance classes.
    pro_val_num = val_nums[0]
    agst_val_num = val_nums[1]
    # presumably the +0.01 guards against a zero denominator when the neutral
    # class is empty -- TODO confirm
    neut_val_num = val_nums[2] + 0.01
    notrel_val_num = val_nums[3]
    
    stance_all_num = pro_val_num + agst_val_num + neut_val_num + notrel_val_num
    
    # Ideology counts are hard-coded placeholders; the real counts from
    # `train_nums_ideology` / `val_nums_ideology` are commented out below.
    con_val_num = 0.1
    lib_val_num = 0.1
    na_val_num = 0.1
    
    con_train_num = 0.1
    lib_train_num = 0.1
    na_train_num = 0.1
    
    #con_train_num = train_nums_ideology[0]
    #lib_train_num = train_nums_ideology[1]
    #na_train_num = train_nums_ideology[2]
    
    # NOTE(review): despite the "train" in the name, this maximum is taken
    # over the *validation* stance counts.
    my_max_train_stance = max(pro_val_num, agst_val_num, neut_val_num, notrel_val_num)
    my_max_train = max(con_train_num, lib_train_num, na_train_num)
    
    #con_val_num = val_nums_ideology[0]
    #lib_val_num = val_nums_ideology[1]
    #na_val_num = val_nums_ideology[2]
    
    # Not read again in this function.
    my_max = max(con_val_num, lib_val_num, na_val_num)
    
    ideology_all_num = con_val_num + lib_val_num + na_val_num
    
    # NOTE(review): `writer` is re-created a few lines below and never
    # written to or closed in this function.
    writer = SummaryWriter()
    min_val_loss = 100
    
    # Not read again in this function.
    relatedness_size = 2
    classes_size = 4
    # This unweighted BCE-with-logits loss is the only loss actually applied
    # in the training and validation loops.
    loss_fct_relatedness = torch.nn.BCEWithLogitsLoss()
    
    # Constructed but not used below.
    loss_fct_stance = torch.nn.CrossEntropyLoss()
    #loss_fct = torch.nn.BCEWithLogitsLoss()
    
    # Loss-mixing coefficients; only referenced by the commented-out combined
    # loss formulas below.
    alpha = 1.5
    beta = 1e-3
    theta = 0
    gamma = 0
    
    # Largest micro-batch pushed through the model at once.
    batch_size_max_once = 16

    if batch_size < batch_size_max_once:
        batch_size_max_once = batch_size
        
    # NOTE(review): true division makes this a float; it is an integer value
    # only when `batch_size` is a multiple of `batch_size_max_once`.
    accumulation_steps = batch_size/batch_size_max_once
    
    es = EarlyStopping(patience,verbose, delta)
    writer = SummaryWriter()

    # Per-epoch records: training and validation loss, validation accuracy
    # and timings.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    # Loaders use the micro-batch size, not the effective batch size.
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)
    
    epoch_start = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        # Multi-GPU: wrap the model so each batch is split across devices.
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            model = torch.nn.DataParallel(model)
            
    print(device)
    
    
            
    if continue_train:    
        # Resume from the checkpoint layout written at the end of each epoch
        # ('epoch', 'state_dict', 'optimizer').
        # NOTE(review): the state dict is loaded after the optional
        # DataParallel wrap, so the checkpoint keys must match that form --
        # confirm when resuming on a different GPU count.
        #'./model_save/fnc/model_emergentbert_epoch90_withoutsep_serp.t7'
        checkpoint = torch.load(model_save_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch']
    
    torch.cuda.empty_cache()
    model.to(device)
    # Optimizer state may have been restored to another device; move it too.
    optimizer_to(optimizer,device)
    
    
    #pos_weight=torch.FloatTensor ([28.36 / 0.5090]
    
    #pos_weight = torch.tensor([1.0, 1.0, 1.0])
    #pos_weight = pos_weight.to(device)
    #criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    # Inverse-frequency class weights (max count / class count). The two
    # weighted losses built from them are not used further down.
    weights_ideology = torch.tensor([my_max_train/con_train_num, my_max_train/lib_train_num, my_max_train/na_train_num]).to(device)   
    weights_stance = torch.tensor([my_max_train_stance/pro_val_num, my_max_train_stance/agst_val_num, my_max_train_stance/neut_val_num, my_max_train_stance/notrel_val_num]).to(device) 
    loss_fct_relatedness_weighted = torch.nn.BCEWithLogitsLoss(pos_weight = weights_stance)
    loss_fct_ideology_weighted = torch.nn.BCEWithLogitsLoss(pos_weight = weights_ideology)
    
    # Counts epochs run in this call (starts at 1); drives the batch-size
    # doubling below.
    batch_epoch_count = 1
    for epoch_i in range(epoch_start, epoch_start + epochs):
        
        print("---------Epoch----------" + str(epoch_i))
        
        # ========================================
        #               Training
        # ========================================
    
        # Perform one full pass over the training set.

        #print("")
        #print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
        #print('Training...')

        # Measure how long the training epoch takes.
        t0 = time.time()

        # Reset the total loss for this epoch.
        total_train_loss = 0
        # Put the model into training mode. Don't be misled -- the call to
        # `train` just changes the *mode*, it doesn't *perform* the training.
        # `dropout` and `batchnorm` layers behave differently during training
        # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
        model.train()
        # NOTE(review): gradients left over from an incomplete accumulation
        # group at the end of the previous epoch are discarded here.
        model.zero_grad()
        optimizer.zero_grad()
        # Not read again in this function.
        mini_batch_avg_loss = 0
        #train_size = len(train_dataloader)
        
        # On every 500th epoch of this call, double the effective batch size.
        # The loaders are not rebuilt; only the accumulation count changes.
        if batch_epoch_count % 500 == 0:
            batch_size = batch_size*2
            accumulation_steps = int(batch_size/batch_size_max_once)
        batch_epoch_count = batch_epoch_count + 1

        #train_size = len(train_dataloader) / float(accumulation_steps)
        
        print("Batch Size: " + str(batch_size))
        print(float(accumulation_steps))
        
        #print("Learning rate: ", scheduler.get_last_lr())
        for step, batch in enumerate(train_dataloader):
            # Not read again in this function.
            elapsed = format_time(time.time() - t0)
        
            # Unpack the eight tensors of the batch and copy them to `device`.
            # Only ids, mask, labels and the two mmd tensors are used below.
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_relatedness = batch[2].to(device)
            b_labels = batch[3].to(device)
            b_mmd_symbol = batch[4].to(device)
            b_mmd_symbol_ = batch[5].to(device)
            b_existedstances = batch[6].to(device)
            b_ideologies = batch[7].to(device)
        
            
            # Forward pass. In the active configuration the model returns a
            # single stance output; the multi-output variant is commented out.

            #mmd_loss, P_relatedness, P_stance, P_existedstance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)
            P_stance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)
                
                
                
            #relatedness_loss = loss_fct_relatedness(P_relatedness, b_relatedness.float())
            stance_loss = loss_fct_relatedness(P_stance, b_labels.float())
            #existedstance_loss = loss_fct_relatedness(P_existedstance, b_existedstances.float())

            
            #loss = alpha * stance_loss + theta * existedstance_loss + beta * mmd_loss + relatedness_loss
            loss = stance_loss
            # Scale so the gradients summed over one accumulation group
            # average rather than add up.
            loss = loss / accumulation_steps 
            total_train_loss += loss.item()
                
            loss.backward()
            if (step+1) % accumulation_steps == 0:             # Wait for several backward steps
                    

                # Clip the norm of the gradients to 1.0.
                # This is to help prevent the "exploding gradients" problem.
                #torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                    
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                
                # Update parameters and take a step using the computed gradient.
                # The optimizer dictates the "update rule"--how the parameters are
                # modified based on their gradients, the learning rate, etc.
                optimizer.step()

                # Update the learning rate.
                scheduler.step()
                
                #for param_group in optimizer.param_groups:
                #print("Learning Rate: ", optimizer.param_groups["lr"])
                
                                
                # Clear the accumulated gradients before the next group.
                # PyTorch doesn't do this automatically because
                # accumulating the gradients is "convenient while training RNNs". 
                # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
                model.zero_grad()
                optimizer.zero_grad()       
    
        print("Learning rate: ", scheduler.get_last_lr())
        # Average loss per micro-batch; multiplying by `accumulation_steps`
        # undoes the scaling applied before `backward()`.
        avg_train_loss = total_train_loss / len(train_dataloader) * accumulation_steps
        #avg_train_loss = total_train_loss / len(train_dataloader)
    
        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        #print("")
        print("  Average training loss: {0:.6f}".format(avg_train_loss))
        print("  Training epoch took: {:}".format(training_time))
        
        
        
        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.

        #print("")
        #print("Running Validation...")

        t1 = time.time()

        # Put the model in evaluation mode--the dropout layers behave differently
        # during evaluation.
        model.eval()

        # Running totals of correct predictions (overall and per class) and
        # of the validation loss.
        total_true_eval_stance = 0
        total_true_eval_ideology = 0
        total_eval_loss = 0
        nb_eval_steps = 0
        
        agree_val_true = 0
        disagree_val_true = 0 
        discuss_val_true = 0 
        unrelated_val_true = 0
        
        con_val_true = 0
        lib_val_true = 0
        na_val_true = 0
        
        # NOTE(review): never updated below, so the "Total True" print always
        # shows 0.
        total_true = 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
        
            # Unpack this validation batch from the dataloader and copy each
            # tensor to `device`. Same eight-tensor layout as in training:
            #   [0]: input ids
            #   [1]: attention masks
            #   [2]: relatedness labels
            #   [3]: stance labels
            #   [4], [5]: mmd tensors passed to the model
            #   [6], [7]: existed-stance and ideology labels (unused here)
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_relatedness = batch[2].to(device)
            b_labels = batch[3].to(device)
            b_mmd_symbol = batch[4].to(device)
            b_mmd_symbol_ = batch[5].to(device)
            b_existedstances = batch[6].to(device)
            b_ideologies = batch[7].to(device)
        
            # Tell pytorch not to bother with constructing the compute graph during
            # the forward pass, since this is only needed for backprop (training).
            with torch.no_grad():        

                # Forward pass; the output is fed to a with-logits loss, so it
                # is treated as raw (pre-sigmoid) scores.

                #mmd_loss, P_relatedness, P_stance, P_existedstance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)
                P_stance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)

                
                # Same BCE-with-logits loss as in training.
                #relatedness_loss = loss_fct_relatedness(P_relatedness, b_relatedness.float())
                stance_loss = loss_fct_relatedness(P_stance, b_labels.float())
                #existedstance_loss = loss_fct_relatedness(P_existedstance, b_existedstances.float())
                
                #loss_val = alpha * stance_loss + beta * mmd_loss + relatedness_loss
                loss_val = stance_loss
                total_eval_loss += loss_val.item()

                # Move model output and labels to CPU
                #P_relatedness = P_relatedness.to('cpu')
                #b_relatedness = b_relatedness.to('cpu')
                P_stance = P_stance.to('cpu')
                b_labels = b_labels.to('cpu')
                #P_existedstance = P_existedstance.to('cpu')
                #b_existedstances = b_existedstances.to('cpu')
                
                

                # Count the correct predictions for this batch and accumulate
                # them over all batches.
                #total_eval_accuracy += predict(P_relatedness, P_stance, b_labels)

                #acc_list = predict_classwise_stance_ideology(P_relatedness, P_stance, P_existedstance, b_labels)
                # acc_list layout as consumed here: [0] total stance hits,
                # [1]-[4] per-class stance hits, [5] total ideology hits,
                # [6]-[8] per-class ideology hits, [9] predicted labels.
                acc_list = predict_classwise_stance_ideology_bert(P_stance, b_labels)
                total_true_eval_stance += acc_list[0]
                ###
                agree_val_true += acc_list[1]
                disagree_val_true += acc_list[2]
                discuss_val_true += acc_list[3]
                unrelated_val_true += acc_list[4]
                
                total_true_eval_ideology += acc_list[5]
                con_val_true += acc_list[6]
                lib_val_true += acc_list[7]
                na_val_true += acc_list[8]
                
                predict_labels = acc_list[9]
                
                                
                ##print("Batch Next")
                #for idx in range(0, len(P_stance)):
                    
                    #print(P_stance[idx], b_labels[idx], acc_list[9][idx]) 

        # Report the final accuracy for this validation run. Ideology figures
        # are divided by the 0.1 placeholder counts set at the top.
        avg_val_accuracy_stance = total_true_eval_stance / stance_all_num
        avg_val_accuracy_ideology = total_true_eval_ideology / ideology_all_num
        print("Avg Val Accuracy Stance: {0:.6f}".format(avg_val_accuracy_stance))
        print("Avg Val Accuracy Ideology: {0:.6f}".format(avg_val_accuracy_ideology))
        print("Total True")
        print(total_true)
        print("*************")
        avg_val_agree_accuracy = agree_val_true / pro_val_num
        print("Avg Val Agree Accuracy: {0:.6f}".format(avg_val_agree_accuracy))
        avg_val_disagree_accuracy = disagree_val_true / agst_val_num
        print("Avg Val Disagree Accuracy: {0:.6f}".format(avg_val_disagree_accuracy))
        avg_val_discuss_accuracy = discuss_val_true / neut_val_num
        print("Avg Val Discuss Accuracy: {0:.6f}".format(avg_val_discuss_accuracy))
        avg_val_unrelated_accuracy = unrelated_val_true / notrel_val_num
        print("Avg Val Unrelated Accuracy: {0:.6f}".format(avg_val_unrelated_accuracy))
        
        # Weighted score: 25% unrelated accuracy plus 75% of the mean of the
        # three related-class accuracies.
        relative_score = 0.25*avg_val_unrelated_accuracy + 0.75*(avg_val_agree_accuracy + avg_val_disagree_accuracy + avg_val_discuss_accuracy)/3
        
        print("*****************")
        print("Relative score: {0:.6f}".format(relative_score))
        print("*****************")
        print("-------------")
        avg_val_con_accuracy = con_val_true / con_val_num
        print("Avg Val Con Accuracy: {0:.6f}".format(avg_val_con_accuracy))
        avg_lib_accuracy = lib_val_true / lib_val_num
        print("Avg Val Lib Accuracy: {0:.6f}".format(avg_lib_accuracy))
        avg_na_discuss_accuracy = na_val_true / na_val_num
        print("Avg Val NA Accuracy: {0:.6f}".format(avg_na_discuss_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)
        
        print("Total Validation loss", total_eval_loss)
        print("Len-validation loader", len(validation_dataloader))
    
        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t1)
        
        # Local best-loss tracker; the value returned at the end comes from
        # the early-stopping object instead.
        if avg_val_loss < min_val_loss:
            min_val_loss = avg_val_loss
    
        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))
        print("  Validation took: {:}".format(validation_time))

        #avg_val_accuracy_ideology = 0
        # Record all statistics from this epoch.
        training_stats.append(
            {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Stance Accur.': avg_val_accuracy_stance,
            'Valid. Ideology Accur.': avg_val_accuracy_ideology,
            'Training Time': training_time,
            'Validation Time': validation_time
            }
        )
        
        # Checkpoint payload; same keys that the resume branch above reads.
        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }
    
        # presumably EarlyStopping decides here whether to save the
        # checkpoint and whether to stop -- verify against its definition
        es.__call__(avg_val_loss, avg_val_accuracy_stance, avg_val_accuracy_ideology, model_save_state, model_save_path, model)
        # NOTE(review): `last_epoch` is only bound inside this loop, so the
        # final `return` raises NameError if the loop never runs
        # (e.g. `epochs == 0`).
        last_epoch = epoch_i + 1
        if es.early_stop == True:
            break  # early stop criterion is met, we can stop now

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))
    
    
    # Report the best values tracked by the early-stopping object.
    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max_stance

    return training_stats, last_epoch, min_val_loss, max_val_acc

from torch.utils.tensorboard import SummaryWriter

#import EarlyStopping
def train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train a binary classifier with gradient accumulation and early stopping.

    Each epoch runs one pass over the training loader (the optimizer and
    scheduler step once every `accumulation_steps` micro-batches), then one
    pass over the validation loader. It appends a row to the statistics list
    and hands a checkpoint dict to `EarlyStopping`, which may end the run
    early.

    NOTE(review): a second function named `train_stance` is defined later in
    this file and replaces this definition once the module has been executed.

    Args:
        model_save_path: Passed to `EarlyStopping`.
        model: Called as `model(input_ids=..., attention_mask=...)`; its
            output goes straight into `BCELoss`, so it is presumably already
            a probability in [0, 1] -- TODO confirm.
        tokenizer: Not read by this function.
        datasetTrain, datasetVal: Datasets whose batches expose at least
            three tensors: [0] input ids, [1] attention mask, [2] labels.
        epochs: Maximum number of epochs to run.
        batch_size: Effective batch size. Micro-batches are capped at 16 and
            the rest is covered by gradient accumulation.
        optimizer, scheduler: Stepped together once per accumulation group.
        patience, verbose, delta: Forwarded positionally to `EarlyStopping`.
        seedVal: Passed to `create_determinism`.
        continue_train: Currently ignored; it is reset to False in the body.

    Returns:
        Tuple `(training_stats, last_epoch, min_val_loss, max_val_acc)`, where
        the last two are `es.val_loss_min` and `es.val_acc_max`.
    """
    
    #loss_fct = torch.nn.BCEWithLogitsLoss()
    loss_fct = torch.nn.BCELoss()
    create_determinism(seedVal)
    
    min_val_loss = 100
    
    # Not read again in this function.
    relatedness_size = 2
    classes_size = 2
    
    # Not read again in this function.
    alpha = 1.3
    theta = 0.8
    beta = 1e-2
    
    # Largest micro-batch pushed through the model at once.
    batch_size_max_once = 16    
    

    if batch_size < batch_size_max_once:
        batch_size_max_once = batch_size
        
    # NOTE(review): true division makes this a float; it is an integer value
    # only when `batch_size` is a multiple of `batch_size_max_once`.
    accumulation_steps = batch_size/batch_size_max_once
    
    es = EarlyStopping(patience,verbose, delta)
    # NOTE(review): `writer` is never written to or closed in this function.
    writer = SummaryWriter()

    # Per-epoch records: training and validation loss, validation accuracy
    # and timings.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    # Loaders use the micro-batch size, not the effective batch size.
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)
    
    epoch_start = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        # Multi-GPU: wrap the model so each batch is split across devices.
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            model = torch.nn.DataParallel(model)
            
    print(device)
          
    # NOTE(review): this unconditionally overrides the `continue_train`
    # argument, so the resume branch below never runs.
    continue_train = False
    if continue_train:
        # NOTE(review): hard-coded checkpoint path; `model_save_path` is not
        # used here.
        checkpoint = torch.load('models/2_a/')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch']
    
    torch.cuda.empty_cache()
    model.to(device)
    # Keep the optimizer state on the same device as the model.
    optimizer_to(optimizer,device)
    
    # Counts epochs run in this call (starts at 1); drives the batch-size
    # doubling below.
    batch_epoch_count = 1
    # NOTE(review): the loop always starts at 0; `epoch_start` is not used.
    for epoch_i in range(0, epochs):
        print("---------Epoch----------" + str(epoch_i))
        
        # ========================================
        #               Training
        # ========================================
    
        # Perform one full pass over the training set.


        # Measure how long the training epoch takes.
        t0 = time.time()

        # Reset the total loss for this epoch.
        total_train_loss = 0

        # Put the model into training mode. Don't be misled -- the call to
        # `train` just changes the *mode*, it doesn't *perform* the training.
        # `dropout` and `batchnorm` layers behave differently during training
        # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
        model.train()
        # NOTE(review): gradients left over from an incomplete accumulation
        # group at the end of the previous epoch are discarded here.
        model.zero_grad()
        optimizer.zero_grad()
        # Loss accumulated over the current accumulation group.
        mini_batch_avg_loss = 0
        
        
        # On every 200th epoch of this call, double the effective batch size.
        # The loaders are not rebuilt; only the accumulation count changes.
        if batch_epoch_count % 200 == 0:
            batch_size = batch_size*2
            accumulation_steps = int(batch_size/batch_size_max_once)
        batch_epoch_count = batch_epoch_count + 1
        
        # Nominal number of optimizer steps per epoch; used as the
        # denominator of the average training loss.
        train_size = len(train_dataloader) / accumulation_steps
        
        print("Batch Size: " + str(batch_size))
        print(accumulation_steps)
        
        for step, batch in enumerate(train_dataloader):
            # Not read again in this function.
            elapsed = format_time(time.time() - t0)
        
            # Unpack ids, mask and labels and copy them to `device`.
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            #b_stance = batch[2].to(device)
            b_ideology = batch[2].to(device)
            #b_mmd_symbol = batch[4].to(device)
            #b_mmd_symbol_ = batch[5].to(device)
            
            # Forward pass; the model's single output is compared with the
            # float-cast labels using BCELoss.

            P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)      
            ideology_loss = loss_fct(P_ideology, b_ideology.float())

            loss = ideology_loss

            #loss = torch.sum(loss, dim=0)


            # Scale the loss so the gradients summed over one accumulation
            # group average rather than add up. `loss` is a tensor holding a
            # single value; `.item()` returns it as a Python number.
            #loss_train = loss
            loss_train = loss / accumulation_steps
            
            #loss_length = torch.numel(loss_train)
            #fill_length = batch_size_max_once-loss_length
            #cat_tensor = torch.zeros(fill_length, device=device)

            #if loss_length < batch_size_max_once:
                #loss_train = torch.cat([loss_train, cat_tensor], dim=0)
                
            mini_batch_avg_loss += loss_train.item()
            
            # Perform a backward pass to calculate the gradients.
            loss_train.backward()
            if (step+1) % accumulation_steps == 0:             # Wait for several backward steps
                # Clip the norm of the gradients to 1.0.
                # This is to help prevent the "exploding gradients" problem.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                
                # Update parameters and take a step using the computed gradient.
                # The optimizer dictates the "update rule"--how the parameters are
                # modified based on their gradients, the learning rate, etc.
                optimizer.step()

                # Update the learning rate.
                scheduler.step()
                
                #for param_group in optimizer.param_groups:
                
                                
                # Clear the accumulated gradients before the next group.
                # PyTorch doesn't do this automatically because
                # accumulating the gradients is "convenient while training RNNs". 
                # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)
                model.zero_grad()
                optimizer.zero_grad()
                #total_train_loss = 0           
                
                # Fold the finished group's loss into the epoch total.
                # NOTE(review): a trailing incomplete group is never added
                # here, and its partial sum is reset at the next epoch.
                total_train_loss += mini_batch_avg_loss
                mini_batch_avg_loss = 0
    
        print("Learning rate: ", scheduler.get_last_lr())
        # Average loss per optimizer step.
        
        avg_train_loss = total_train_loss / train_size
    
        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("  Average training loss: {0:.6f}".format(avg_train_loss))
        
        # ========================================
        #               Validation
        # ========================================
        # After the completion of each training epoch, measure our performance on
        # our validation set.


        # `t0` is reused as the validation start time; `training_time` was
        # already captured above.
        t0 = time.time()

        # Put the model in evaluation mode--the dropout layers behave differently
        # during evaluation.
        model.eval()

        # Tracking variables; only the accuracy and loss totals are updated.
        total_eval_accuracy = 0
        total_eval_loss = 0
        total_eval_stanceloss = 0
        total_eval_ideologicalloss = 0
        nb_eval_steps = 0

        # Evaluate data for one epoch
        for batch in validation_dataloader:
        
            # Unpack this validation batch from the dataloader and copy each
            # tensor to `device`. Three tensors are read:
            #   [0]: input ids 
            #   [1]: attention masks
            #   [2]: labels 
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            #b_stance = batch[2].to(device)
            b_ideology = batch[2].to(device)
            #b_mmd_symbol = batch[4].to(device)
            #b_mmd_symbol_ = batch[5].to(device)
            
            

            # Tell pytorch not to bother with constructing the compute graph during
            # the forward pass, since this is only needed for backprop (training).
            with torch.no_grad():        

                # Forward pass with the same call signature and loss as in
                # training.
                
                
                P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)

                ideology_loss = loss_fct(P_ideology, b_ideology.float())

                loss_val = ideology_loss
                #loss_val = torch.sum(loss, dim=0).item()
                
                #logits = model(input_ids = b_input_ids,attention_mask=b_input_mask)
                
                #loss = loss_function(logits, b_labels)
            
                # Accumulate the validation loss.
                total_eval_loss += loss_val.item()

                # Move model output and labels to CPU
                P_ideology = P_ideology.to('cpu')
                b_ideology = b_ideology.to('cpu')

                # Calculate the accuracy for this batch of validation samples
                # and accumulate it over all batches.
                total_eval_accuracy += predict_binary(P_ideology, b_ideology)
        

        # Report the final accuracy for this validation run: the mean of the
        # per-batch accuracies.
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("Avg Val Accuracy: {0:.6f}".format(avg_val_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)
    
        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)
        
        # Local best-loss tracker; the value returned at the end comes from
        # the early-stopping object instead.
        if avg_val_loss < min_val_loss:
            min_val_loss = avg_val_loss
    
        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))
        #print("  Validation took: {:}".format(validation_time))

        # Record all statistics from this epoch.
        training_stats.append(
            {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
            }
        )
    
        # Checkpoint payload; same keys that the resume branch above reads.
        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }
    
        # presumably EarlyStopping decides here whether to save the
        # checkpoint and whether to stop -- verify against its definition
        es.__call__(avg_val_loss, avg_val_accuracy, model_save_state, model_save_path)
        # NOTE(review): `last_epoch` is only bound inside this loop, so the
        # final `return` raises NameError if the loop never runs
        # (e.g. `epochs == 0`).
        last_epoch = epoch_i + 1
        if es.early_stop == True:
            break  # early stop criterion is met, we can stop now

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))
    
    
    # Report the best values tracked by the early-stopping object.
    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max

    return training_stats, last_epoch, min_val_loss, max_val_acc

from torch.utils.tensorboard import SummaryWriter

#import EarlyStopping
def train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train ``model`` with binary cross-entropy and validate after every epoch.

    At most 16 samples are pushed through the model per forward/backward
    pass. Gradients are accumulated over ``batch_size // 16`` such passes
    (at least one) before each optimizer and scheduler step. After every
    epoch the validation loss and accuracy are computed and passed, together
    with a checkpoint dict (``epoch``, ``state_dict``, ``optimizer``), to an
    ``EarlyStopping`` instance that decides whether to stop.

    Args:
        model_save_path: Path handed to ``EarlyStopping`` for checkpointing.
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
            Its output is fed straight into ``torch.nn.BCELoss``, so it must
            already be a probability in ``[0, 1]``.
        tokenizer: Unused; kept so existing callers keep working.
        datasetTrain: Training dataset passed to ``return_batches_datasets``.
        datasetVal: Validation dataset passed to ``return_batches_datasets``.
        epochs: Maximum number of epochs.
        batch_size: Effective batch size (reached via gradient accumulation).
        optimizer: Optimizer stepped once per accumulated batch.
        scheduler: Learning-rate scheduler stepped together with the optimizer.
        patience: Forwarded to ``EarlyStopping``.
        verbose: Forwarded to ``EarlyStopping``.
        delta: Forwarded to ``EarlyStopping``.
        seedVal: Seed passed to ``create_determinism``.
        continue_train: When true, model and optimizer state are restored
            from a checkpoint before training.

    Returns:
        Tuple ``(training_stats, last_epoch, min_val_loss, max_val_acc)``:
        the per-epoch statistics dicts, the 1-based number of the last epoch
        that ran (0 if none ran), and ``es.val_loss_min`` / ``es.val_acc_max``
        as tracked by the early-stopping helper.
    """
    # BCELoss (not BCEWithLogitsLoss): the model output is used as-is.
    loss_fct = torch.nn.BCELoss()
    create_determinism(seedVal)

    # Largest batch that is sent through the model in one pass.
    batch_size_max_once = min(16, batch_size)

    # Must be a whole number >= 1: it is used both as the loss divisor and as
    # the modulus that decides when the optimizer steps.
    accumulation_steps = max(1, batch_size // batch_size_max_once)

    es = EarlyStopping(patience,verbose, delta)
    writer = SummaryWriter()

    # Per-epoch training/validation loss, validation accuracy and timings.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # DataParallel splits each batch along dim 0 across the GPUs.
        model = torch.nn.DataParallel(model)

    print(device)

    if continue_train:
        # NOTE(review): the checkpoint location is hard-coded and looks like a
        # directory rather than a file -- confirm the intended path. The epoch
        # stored in the checkpoint is not used; epoch numbering restarts at 0.
        checkpoint = torch.load('models/2_a/')
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    torch.cuda.empty_cache()
    model.to(device)
    optimizer_to(optimizer,device)

    # Stays 0 when `epochs` is 0 so the return statement is always valid.
    last_epoch = 0

    for epoch_i in range(0, epochs):
        print("---------Epoch----------" + str(epoch_i))

        # ========================================
        #               Training
        # ========================================

        # Measure how long the training epoch takes.
        t0 = time.time()

        # Sum of the (already scaled) per-batch losses of this epoch.
        total_train_loss = 0

        # `train()` only switches the mode (dropout/batchnorm behaviour);
        # it does not perform any training by itself.
        model.train()
        model.zero_grad()
        optimizer.zero_grad()

        # Every 200th epoch the effective batch size is doubled.
        if (epoch_i + 1) % 200 == 0:
            batch_size = batch_size*2
            accumulation_steps = max(1, batch_size // batch_size_max_once)

        # Number of accumulated (effective) batches per epoch.
        train_size = len(train_dataloader) / accumulation_steps

        print("Batch Size: " + str(batch_size))
        print(accumulation_steps)

        for step, batch in enumerate(train_dataloader):
            # batch = (input ids, attention masks, labels)
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            # Forward pass on this training batch.
            P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)
            loss = loss_fct(P_ideology, b_ideology.float())

            # Scale so that the accumulated gradient is the mean over the
            # effective batch.
            loss_train = loss / accumulation_steps

            # Accumulated per batch (not per completed group) so batches of a
            # trailing incomplete group are still counted in the average.
            total_train_loss += loss_train.item()

            # Backward pass: gradients add up across accumulation steps.
            loss_train.backward()
            if (step+1) % accumulation_steps == 0:
                # Clip the gradient norm to 1.0 to guard against exploding
                # gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()

                # Update the learning rate.
                scheduler.step()

                # Clear the gradients before the next accumulation group.
                model.zero_grad()
                optimizer.zero_grad()

        # Gradients of a trailing incomplete group are never applied; they
        # are cleared by the zero_grad calls at the start of the next epoch.

        print("Learning rate: ", scheduler.get_last_lr())

        # Mean per-batch loss over the epoch.
        avg_train_loss = total_train_loss / train_size

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("  Average training loss: {0:.6f}".format(avg_train_loss))

        # ========================================
        #               Validation
        # ========================================

        t0 = time.time()

        # Evaluation mode: dropout layers behave differently than in training.
        model.eval()

        total_eval_accuracy = 0
        total_eval_loss = 0

        for batch in validation_dataloader:
            # batch = (input ids, attention masks, labels)
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            # No graph is needed for evaluation.
            with torch.no_grad():
                P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)
                loss_val = loss_fct(P_ideology, b_ideology.float())

                # Accumulate the validation loss.
                total_eval_loss += loss_val.item()

                # Accuracy is computed on CPU copies.
                P_ideology = P_ideology.to('cpu')
                b_ideology = b_ideology.to('cpu')

                total_eval_accuracy += predict_binary(P_ideology, b_ideology)

        # Report the final accuracy for this validation run.
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("Avg Val Accuracy: {0:.6f}".format(avg_val_accuracy))

        # Average loss over all validation batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t0)

        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))

        # Record all statistics from this epoch.
        training_stats.append(
            {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
            }
        )

        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }

        es(avg_val_loss, avg_val_accuracy, model_save_state, model_save_path)
        last_epoch = epoch_i + 1
        if es.early_stop:
            break  # early stop criterion is met, we can stop now

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

    # Release the event file opened by the SummaryWriter.
    writer.close()

    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max

    return training_stats, last_epoch, min_val_loss, max_val_acc

def print_summary(training_stats):
    """Build a per-epoch statistics table from the training history.

    Also adjusts global pandas display options (4 decimal places, up to 500
    rows and columns) so the table is shown in full.

    Args:
        training_stats: List of dicts, one per epoch, each containing at
            least an ``'epoch'`` key.

    Returns:
        A ``pandas.DataFrame`` with one row per epoch, indexed by ``'epoch'``.
    """
    # Use the fully qualified option name: the bare 'precision' pattern is
    # ambiguous in pandas versions that also define 'styler.format.precision'.
    pd.set_option('display.precision', 4)

    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', 500)

    # Create a DataFrame from the training statistics, one row per epoch.
    df_stats = pd.DataFrame(data=training_stats)

    # Use the 'epoch' as the row index.
    df_stats = df_stats.set_index('epoch')

    return df_stats

def plot_results(df_stats, last_epoch):
    """Plot training and validation loss per epoch and show the figure.

    Args:
        df_stats: DataFrame with 'Training Loss' and 'Valid. Loss' columns.
        last_epoch: Number of the last epoch that was run.
    """
    # Seaborn styling, then a larger font scale.
    sns.set(style='darkgrid')
    sns.set(font_scale=1.5)
    plt.rcParams["figure.figsize"] = (12,6)

    plot1 = plt.figure(1)

    # (column, line format, legend label) for each curve, drawn in order.
    curves = (
        ('Training Loss', 'b-o', "Training_Loss"),
        ('Valid. Loss', 'g-o', "Val_Loss"),
    )
    for column, line_format, legend_label in curves:
        plt.plot(df_stats[column], line_format, label=legend_label)

    # Label the plot.
    plt.title("Training & Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

    # Epoch numbers 1..last_epoch; currently not applied as explicit ticks.
    x_ticks = list(range(1, last_epoch+1))
    plt.xticks(rotation=90)

    plt.show()

def run_wholeprocess_fnc(tokenizer, model_current, train_path, val_path, max_len, doc_stride, batch_size, num_warmup_steps, learning_rate, seedVal):
    """Load the dataset, train once, then repeat train/test for 5 random seeds.

    Steps, in order: load the hard-coded TSV dataset and split it with
    ``sample_dataset_stance``; build the text/label lists per split; encode
    the texts with ``preprocessing_for_bert``; train and test one model; then
    run five further train/test iterations with random seeds and print the
    averaged validation/test metrics. Nothing is returned.

    NOTE(review): ``train_path`` and ``val_path`` are never read; the dataset
    path is hard-coded below.
    NOTE(review): ``epochs``, ``patience``, ``verbose`` and ``delta`` are not
    parameters or locals here, so they are resolved as module-level names.
    """

#--------------LOAD DATASETS--------------#

    model_save_path = './models/BERT_SERP/bert_titleonly_ambigious'  
    df = load_dataset_ambigious("./dataset/batches_cleaned/stance/FullDataset_16.09.2021.tsv")
    
    #trainpath = "./dataset/ideology/train_new.tsv"
    #valPath = "./dataset/ideology/val_new.tsv"
    #testPath = "./dataset/ideology/test_new.tsv"
    
    # NOTE(review): these three ratios are assigned but not used below.
    trainPer = 0.8
    valPer = 0.2
    testPer = 0.2
    
    # Split into train / validation / test frames; `df` becomes the train split.
    df, dfVal, dfTest = sample_dataset_stance(df, seedVal)
    
    #create_new_splits_and_writethem_to_csvfiles
    #create_train_val_test_split(trainpath, valPath, testPath, testPer, valPer, seedVal)
    
    ##df = load_dataset(trainpath)
    #dfVal = load_dataset(valPath)
    #dfTest = load_dataset(testPath)
    
    # Report the number of sentences.
    print('Number of training sentences: {:,}'.format(df.shape[0]))
    print('Number of val sentences: {:,}'.format(dfVal.shape[0]))
    print('Number of test sentences: {:,}'.format(dfTest.shape[0]))

    # Placeholders; the first, second and fourth are overwritten right below,
    # `stances_Train` stays an empty list.
    sentencesQueryTitle_Train = []
    sentencesQueryTitleCont_Train = []
    stances_Train = []
    labels_Train = []
    
    #--------------DATASETS-------------#

    #sentencesQueryTitle_Train, sentencesQueryTitleCont_Train, sentencesQueryTitleStance_Train, sentencesQueryTitleStanceCont_Train, stances_Train, labels_Train = generate_datasets_ideology (df, tokenizer)
    sentencesQueryTitle_Train, sentencesQueryTitleCont_Train, labels_Train = generate_datasets_ambigious(df, tokenizer)
    
    
    
    sentencesQueryTitle_Val = []
    sentencesQueryTitleCont_Val = []
    stances_Val = []
    labels_Val = []

 
    #sentencesQueryTitle_Val, sentencesQueryTitleCont_Val, sentencesQueryTitleStance_Val, sentencesQueryTitleStanceCont_Val, stances_Val, labels_Val = generate_datasets_ideology (dfVal, tokenizer)
    sentencesQueryTitle_Val, sentencesQueryTitleCont_Val, labels_Val = generate_datasets_ambigious(dfVal, tokenizer)
    
    sentencesQueryTitle_Test = []
    sentencesQueryTitleCont_Test = []
    stances_Test = []
    labels_Test = []
   
    #sentencesQueryTitle_Test, sentencesQueryTitleCont_Test, sentencesQueryTitleStance_Test, sentencesQueryTitleStanceCont_Test, stances_Test, labels_Test = generate_datasets_ideology (dfTest, tokenizer)
    sentencesQueryTitle_Test, sentencesQueryTitleCont_Test, labels_Test = generate_datasets_ambigious(dfTest, tokenizer)
    
    # Show the first training example as a sanity check.
    print(sentencesQueryTitle_Train[0])

    #--------------DATASETS-------------#

    # Encode the "...Cont" text variant of each split into input ids and
    # attention masks.
    all_input_ids_Train, all_input_masks_Train  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Train, max_len, doc_stride)
    all_input_ids_Val, all_input_masks_Val  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Val, max_len, doc_stride)
    all_input_ids_Test, all_input_masks_Test  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Test, max_len, doc_stride)
    
    #all_input_ids_Train, all_input_masks_Train, stance_labels_Train, ideology_labels_Train = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Train, stances_Train, labels_Train, max_len, doc_stride) #train
    #all_input_ids_Val, all_input_masks_Val, stance_labels_Val, ideology_labels_Val = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Val, stances_Val, labels_Val, max_len, doc_stride) #val
    #all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Test, stances_Test, labels_Test, max_len, doc_stride) #test

    # First (single) training run on the splits prepared above.
    model, datasetTrain, datasetVal, optimizer, scheduler = prepare_for_training_ambigious(all_input_ids_Train, all_input_masks_Train, labels_Train, all_input_ids_Val,
                                                                                                               all_input_masks_Val, labels_Val, model_current, batch_size, epochs, num_warmup_steps, learning_rate)    
    training_stats, last_epoch, min_val_loss, max_val_acc = train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer,
                                                                          scheduler, patience, verbose, delta, seedVal)
    
    
    # NOTE(review): `stance_labels_Test` is not assigned anywhere in this
    # function (only in the commented-out lines above); unless it exists as a
    # module-level name this call raises NameError -- confirm intended argument.
    avg_test_loss, avg_test_acc = run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, labels_Test, batch_size)
    df_stats = print_summary(training_stats)
    plot_results(df_stats, last_epoch)
    
    #--------------TRAINING-------------#
    # Per-pass batch size, capped at 16.
    batch_size_cuda = 16
    if batch_size < 16:
        batch_size_cuda = batch_size
        
    # Repeat train/test several times and average the metrics.
    num_iterations = 5
    total_val_loss = 0.0
    total_val_acc = 0.0
    total_test_loss = 0.0
    total_test_acc = 0.0
    
    for i in range(0, num_iterations):
        
        # A fresh random seed per iteration replaces the `seedVal` argument.
        value = randint(0, 100)
        seedVal = value
        print("******************")
        print("This is the iteration " + str(i))
        
        # One checkpoint file per iteration (iteration index appended).
        model_save_path = "model_save/ideology/model_news2a_qtitle.t7" + str(i)

        # NOTE(review): `stance_labels_Train`, `ideology_labels_Train`,
        # `stance_labels_Val`, `ideology_labels_Val`, `stance_labels_Test` and
        # `ideology_labels_Test` are not assigned in this function, and this
        # loop uses `prepare_for_training` rather than the `_ambigious`
        # variant used above -- confirm which pipeline is intended.
        model, datasetTrain, datasetVal, optimizer, scheduler = prepare_for_training(all_input_ids_Train, all_input_masks_Train, stance_labels_Train, ideology_labels_Train, all_input_ids_Val,
                                                                                                               all_input_masks_Val, stance_labels_Val, ideology_labels_Val, model_current, batch_size_cuda, epochs, num_warmup_steps, learning_rate)    
        training_stats, last_epoch, min_val_loss, max_val_acc = train_stance (model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer,
                                                                          scheduler, patience, verbose, delta, seedVal)

        avg_test_loss, avg_test_acc = run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size)
        df_stats = print_summary(training_stats)
        plot_results(df_stats, last_epoch)
        
        # Accumulate the metrics for the averages printed after the loop.
        total_val_loss += min_val_loss
        total_val_acc += max_val_acc
        total_test_loss += avg_test_loss
        total_test_acc += avg_test_acc

        print('Min Val Loss: ' + str(min_val_loss))
        print('Max Val Acc: ' + str(max_val_acc))
        print('Test Loss: ' + str(avg_test_loss))
        print('Test Acc: ' + str(avg_test_acc))
        
        
    # Averages over all iterations.
    print("******************")
    print('Avg Min Val Loss: ' + str(total_val_loss/num_iterations))
    print('Avg Max Val Acc: ' + str(total_val_acc/num_iterations))
    print('Avg Test Loss: ' + str(total_test_loss/num_iterations))
    print('Avg Test Acc: ' + str(total_test_acc/num_iterations))
    
    
    #model_to_save.save_pretrained('model_save')
    #tokenizer.save_pretrained('model_save')

    # Good practice: save your training arguments together with the trained model
    #torch.save(args, os.path.join('model_save', 'training_args.bin'))
    #model_args = str(max_len) + '_' + str(doc_stride) + '_' + str(batch_size) + "_" + str(learning_rate) + "_warmup" + str(num_warmup_steps) + "_seedVal" + str(seedVal)
    #model_path = model_save_path + '/model_' + model_args
    #model.save_pretrained(model_save_path)
    #torch.save(model.state_dict(), model_path)

def create_determinism(seedVal):
    """Seed the PyTorch, NumPy and Python random number generators.

    Args:
        seedVal: Integer seed applied to ``torch`` (CPU and all CUDA
            devices), ``numpy.random`` and the ``random`` module.

    Returns:
        None. The function only has the seeding side effect.
    """
    torch.manual_seed(seedVal)
    torch.cuda.manual_seed_all(seedVal)
    torch.cuda.manual_seed(seedVal)
    np.random.seed(seedVal)
    random.seed(seedVal)
    # Stricter (slower) reproducibility switches, intentionally left disabled:
    #os.environ['PYTHONHASHSEED'] = str(seedVal)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False

from torch.utils.data import DataLoader, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig

def run_test_stance(model_savepath, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size = 16):
    """Evaluate a saved ``StanceDetectionClass`` checkpoint on the test set.

    Args:
        model_savepath: Checkpoint file containing a ``'state_dict'`` entry.
        all_input_ids_Test: Tensor of test input ids.
        all_input_masks_Test: Tensor of test attention masks.
        stance_labels_Test: Labels passed to ``preprocess_fnc``, which yields
            the relatedness, stance and two MMD-symbol tensors.
        ideology_labels_Test: Unused; kept so existing callers keep working.
        batch_size: Evaluation batch size.

    Returns:
        Tuple ``(avg_test_loss, avg_test_accuracy)`` averaged over batches.
    """
    loss_fct_relatedness = torch.nn.BCEWithLogitsLoss()

    t_test_relatedness, t_test_stance, t_test_mmd_symbol, t_test_mmd_symbol_ = preprocess_fnc(stance_labels_Test)
    # Create the DataLoader; sequential order, no shuffling.
    prediction_data = TensorDataset(all_input_ids_Test, all_input_masks_Test, t_test_relatedness, t_test_stance, t_test_mmd_symbol, t_test_mmd_symbol_)
    prediction_sampler = SequentialSampler(prediction_data)
    prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size, num_workers=0)

    # Resolve the device before anything is moved onto it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_current = 'bert-base-uncased'
    model = StanceDetectionClass(model_current)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_savepath, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Evaluation needs no optimizer, so none is built or restored here.
    torch.cuda.empty_cache()
    model.to(device)

    # Put model in evaluation mode
    model.eval()

    # Tracking variables
    total_test_loss = 0.0
    total_test_accuracy = 0.0

    # Weights of the stance term and the MMD term in the combined loss.
    alpha = 1.3
    beta = 1e-3

    for batch in prediction_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_relatedness = batch[2].to(device)
        b_labels = batch[3].to(device)
        b_mmd_symbol = batch[4].to(device)
        b_mmd_symbol_ = batch[5].to(device)

        # No gradients are needed for prediction.
        with torch.no_grad():
            # Number of samples flagged by each MMD symbol in this batch.
            n1 = torch.sum(b_mmd_symbol, dim=0)
            n2 = torch.sum(b_mmd_symbol_, dim=0)

            # Column vectors used to mask the rows of theta_d_layer.
            aa = torch.reshape(b_mmd_symbol, (-1,1))
            bb = torch.reshape(b_mmd_symbol_, (-1,1))

            theta_d_layer, P_relatedness, P_stance = model(input_ids = b_input_ids, attention_mask = b_input_mask)

            # Guard against division by zero when a group is empty.
            if n1 == 0:
                d1 = torch.zeros(batch_size, 1, device = device)
            else:
                d1 = torch.div(torch.sum(theta_d_layer*aa, dim=1), n1)

            if n2 == 0:
                d2 = torch.zeros(batch_size, 1, device = device)
            else:
                d2 = torch.div(torch.sum(theta_d_layer*bb, dim=1), n2)

            mmd_loss = torch.sum(d1 - d2)

            relatedness_loss = loss_fct_relatedness(P_relatedness, b_relatedness.float())
            stance_loss = loss_fct_relatedness(P_stance, b_labels.float())

            loss_test = relatedness_loss + alpha * stance_loss - beta * mmd_loss
            total_test_loss += loss_test.item()

            # Move logits and labels to CPU
            P_relatedness = P_relatedness.to('cpu')
            b_relatedness = b_relatedness.to('cpu')
            P_stance = P_stance.to('cpu')
            b_labels = b_labels.to('cpu')

            # NOTE(review): the `predict` visible in this file takes two
            # arguments; this three-argument call presumably targets another
            # definition -- confirm.
            total_test_accuracy += predict(P_relatedness, P_stance, b_labels)

    # Average over all test batches.
    avg_test_loss = total_test_loss / len(prediction_dataloader)
    avg_test_accuracy = total_test_accuracy / len(prediction_dataloader)

    return avg_test_loss, avg_test_accuracy

from torch.utils.data import DataLoader, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig

def run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size = 16):
    """Evaluate a saved ``IdeologyDetectionClass`` checkpoint on the test set.

    Args:
        model_save_path: Checkpoint file containing a ``'state_dict'`` entry.
        all_input_ids_Test: Tensor of test input ids.
        all_input_masks_Test: Tensor of test attention masks.
        stance_labels_Test: Passed to ``preprocess_ideology_new``.
        ideology_labels_Test: Passed to ``preprocess_ideology_new``.
        batch_size: Evaluation batch size.

    Returns:
        Tuple ``(avg_test_loss, avg_test_accuracy)`` averaged over batches.
    """
    t_ideology_labels_test = preprocess_ideology_new(stance_labels_Test, ideology_labels_Test)
    # Create the DataLoader; sequential order, no shuffling.
    prediction_data = TensorDataset(all_input_ids_Test, all_input_masks_Test, t_ideology_labels_test)
    prediction_sampler = SequentialSampler(prediction_data)
    prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)

    # The model output is fed to BCELoss directly (no sigmoid applied here).
    loss_fct = torch.nn.BCELoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_current = 'bert-base-uncased'
    model = IdeologyDetectionClass(model_current)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Evaluation needs no optimizer, so none is built or restored here.
    torch.cuda.empty_cache()
    model.to(device)

    # Put model in evaluation mode
    model.eval()

    # Tracking variables
    total_test_loss = 0.0
    total_test_accuracy = 0.0

    for batch in prediction_dataloader:
        # batch = (input ids, attention masks, labels)
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_ideologylabels = batch[2].to(device)

        # No gradients are needed for prediction.
        with torch.no_grad():
            P_ideology = model(b_input_ids, attention_mask=b_input_mask)

            loss = loss_fct(P_ideology, b_ideologylabels.float())

            # Accuracy is computed on CPU copies.
            P_ideology = P_ideology.detach().cpu()
            batch_labels_cpu = b_ideologylabels.to('cpu')

            total_test_loss += loss.item()
            total_test_accuracy += predict_binary(P_ideology, batch_labels_cpu)

    # Average over all test batches.
    avg_test_loss = total_test_loss / len(prediction_dataloader)
    avg_test_accuracy = total_test_accuracy / len(prediction_dataloader)

    return avg_test_loss, avg_test_accuracy

### import os
import string
import tensorflow as tf
import torch
import pandas as pd
import numpy as np
from random import randint
import random
import time
import datetime
from transformers import AutoModel
from transformers import DistilBertModel
from torch.utils.data import TensorDataset, random_split
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# ---------------------------------------------------------------------------
# Run configuration and entry point.
# ---------------------------------------------------------------------------
#model = "bert-base-uncased"
train_path = './dataset/fnc/train'
val_path = './dataset/fnc/test'

model_save_path = './model_save/'
model_name = 'querytitle_model_base_9'

# `run_utils()` is not used here because it raises when TensorFlow sees no
# GPU; the device is resolved directly so the test-only branch below has it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Available pretrained / fine-tuned model identifiers.
model_base = 'bert-base-uncased'
model_roberta = "roberta-base"
model_finetuned = './models/2_a/'
model_finetuned2 = './models/2_b/'
model_tiny_bert = './models/tiny_bert/'

model_current = model_base
tokenizer = load_tokenizer(model_current)

# Tokenization limits.
max_len = 512
doc_stride = 128

# Optimization hyper-parameters.
batch_size = 16
epochs = 30
num_warmup_steps = 10
learning_rate = 2e-6

##-----Early Stopping
patience = 4000
verbose = True
delta = 0.000001
seedVal = 20

# True: run the full training pipeline; False: only evaluate a saved model.
train_flag = True

if train_flag:
    run_wholeprocess_fnc(tokenizer, model_current, train_path, val_path, max_len, doc_stride, batch_size, num_warmup_steps, learning_rate,seedVal)  
else:
    # NOTE(review): `run_onlytest_ideology` is not defined in the visible
    # part of this file -- confirm it exists before using this branch.
    avg_test_loss, avg_test_acc = run_onlytest_ideology(tokenizer, model_save_path + model_name, torch.nn.BCELoss(), device)
    print(avg_test_loss, avg_test_acc)






def run_utils():
    """Check for a GPU and return the torch device to run on.

    Raises:
        SystemError: If TensorFlow does not report '/device:GPU:0'.

    Returns:
        ``torch.device("cuda")`` when PyTorch sees CUDA, otherwise
        ``torch.device("cpu")``.
    """
    # TensorFlow must report exactly this device name, otherwise abort.
    tf_gpu_name = tf.test.gpu_device_name()
    if tf_gpu_name != '/device:GPU:0':
        raise SystemError('GPU device not found')
    print('Found GPU at: {}'.format(tf_gpu_name))

    # PyTorch has no CUDA: fall back to the CPU.
    if not torch.cuda.is_available():
        print('No GPU available, using the CPU instead.')
        return torch.device("cpu")

    # PyTorch sees CUDA: use it and report what is available.
    cuda_device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
    return cuda_device


def predict(P_ideology, ideology_labels):
    """Return the accuracy of arg-max predictions against arg-max labels.

    Both arguments are reduced with ``argmax`` over dimension 1; the result
    is the number of matching entries divided by the number of rows.
    """
    predicted_classes = P_ideology.argmax(1)
    target_classes = ideology_labels.argmax(1)

    # Element count of the index tensor of all matching positions.
    hits = torch.nonzero(predicted_classes == target_classes).numel()

    return hits / len(predicted_classes)


def predict_binary(P_ideology, ideology_labels):
    """Return the accuracy of rounded probabilities against binary labels.

    Args:
        P_ideology: Tensor of predicted probabilities; each value is rounded
            to the nearest integer to obtain the predicted label.
        ideology_labels: Tensor of target labels compared element-wise with
            the rounded predictions.

    Returns:
        Number of matching elements divided by ``len(P_ideology)``, as a
        Python float.
    """
    # torch.round keeps the computation in torch, so it also works for
    # tensors that require grad or live on a GPU (NumPy conversion does not).
    predict_labels = torch.round(P_ideology).int()

    true_predict_count = torch.eq(predict_labels, ideology_labels).sum().item()

    return true_predict_count / len(predict_labels)


def find_maxLen_doc(data, tokenizer):
    """Print and return the longest tokenized length among ``data``.

    Args:
        data: Iterable of texts.
        tokenizer: Object with an ``encode(text, add_special_tokens=True)``
            method returning a sequence of token ids.

    Returns:
        The maximum encoded length (special tokens included), or 0 when
        ``data`` is empty.
    """
    max_len = max(
        (len(tokenizer.encode(sent, add_special_tokens=True)) for sent in data),
        default=0,
    )

    print('Max sentence length: ', max_len)
    return max_len

def format_time(elapsed):
    '''
    Convert a duration in seconds into the string form of a
    ``datetime.timedelta`` (e.g. ``1:01:01``), after rounding to the
    nearest whole second.
    '''
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

def load_tokenizer(model):
    """Load the lower-casing BERT tokenizer for a pretrained model.

    Args:
        model: name or path passed to ``BertTokenizer.from_pretrained``.

    Returns:
        The ``BertTokenizer`` created with ``do_lower_case=True``.
    """
    # Import only what is used: the extra names imported here before
    # (e.g. the deprecated AutoModelWithLMHead) could make this call fail
    # without ever being needed.
    from transformers import BertTokenizer

    return BertTokenizer.from_pretrained(model, do_lower_case=True)

def load_dataset(path):
    """Read a tab-separated ideology dataset and lower-case its text columns.

    The first line of the file is treated as a header and replaced by the
    fixed column names below. ``docCont``, ``Q`` and ``title`` are
    lower-cased; per-class row counts for ``ideology`` and ``stance`` are
    printed.

    Args:
        path: file path or buffer accepted by ``pd.read_csv``.

    Returns:
        The loaded ``DataFrame``.
    """
    column_names = ['qID', 'q_ideology', 'ideology', 'stance', 'docCont', 'topic', 'Q', 'title']
    frame = pd.read_csv(path, delimiter='\t', header=0, names=column_names)

    for text_column in ('docCont', 'Q', 'title'):
        frame[text_column] = frame[text_column].str.lower()

    print("Train")
    class_masks = (
        ("Con", frame.ideology == 0),
        ("Lib", frame.ideology == 1),
        ("Pro", frame.stance == 1),
        ("Against", frame.stance == 0),
    )
    for tag, mask in class_masks:
        print(tag, frame[mask].shape[0])

    return frame

def load_dataset_ambigious(path):
    """Read a tab-separated ambiguity dataset as an all-string DataFrame.

    The first line of the file is treated as a header and replaced by the
    fixed column names below. ``docCont``, ``Q`` and ``title`` are
    lower-cased, then every column is cast to ``str``. Counts of rows whose
    ``Ambigious`` value is ``"1"`` / ``"0"`` and the final dtypes are printed.

    Args:
        path: file path or buffer accepted by ``pd.read_csv``.

    Returns:
        The loaded ``DataFrame`` with all columns converted to strings.
    """
    column_names = ['qID', 'docID', 'stance', 'Ambigious', 'ideology', 'docCont', 'Q', 'title']
    frame = pd.read_csv(path, delimiter='\t', header=0, names=column_names)

    for text_column in ('docCont', 'Q', 'title'):
        frame[text_column] = frame[text_column].str.lower()

    # Everything, including the numeric label columns, becomes a string.
    frame = frame.astype(str)

    print("Train")
    for tag, flag in (("Ambigious", "1"), ("Non-ambigious", "0")):
        print(tag, frame[frame.Ambigious == flag].shape[0])

    print(frame.dtypes)

    return frame

def load_dataset_first(path):
    """Load a document file and attach query ids/orientations from its companion file.

    The companion file path is derived from ``path`` by replacing
    ``final.tsv`` with ``final_onlyqID.tsv``. Its ``qID`` column replaces the
    document file's own ``qID`` and its ``orientation`` column (with -1/1
    mapped to "Lib"/"Con") is inserted as ``q_ideology``. ``stance`` values
    "-1"/"Agst" are mapped to 0 and "1"/"Pro" to 1.

    Args:
        path: path (string) of the tab-separated document file.

    Returns:
        The combined ``DataFrame``.
    """
    document_columns = ['qID', 'ideology', 'stance', 'docCont', 'topic', 'Q', 'title']
    documents = pd.read_csv(path, delimiter='\t', header=0, names=document_columns)

    query_path = path.replace('final.tsv', 'final_onlyqID.tsv')
    queries = pd.read_csv(query_path, delimiter='\t', header=0, names=['qID', 'orientation'])
    queries["orientation"] = queries["orientation"].replace({-1: "Lib", 1: "Con"})

    # Swap the document file's qID column for the one from the companion file.
    documents = documents.drop('qID', axis=1)
    documents.insert(0, "qID", queries['qID'], allow_duplicates=True)
    documents.insert(1, "q_ideology", queries['orientation'], allow_duplicates=True)

    stance_codes = {"-1": 0, "1": 1, "Pro": 1, "Agst": 0}
    documents["stance"] = documents["stance"].replace(stance_codes)

    return documents

def sample_dataset_stance(df, seedVal):
    """Split the ambiguity dataset into train/val/test and write each split to disk.

    Rows whose ``Ambigious`` value is ``"1"`` or ``"0"`` are kept, 20% of
    them are held out as the test split and 25% of the remainder as the
    validation split. The ``Ambigious`` column is re-inserted at position 2
    of every split, class counts are printed for train and test, and the
    splits are written as TSV files under
    ``./dataset/batches_cleaned/stance/``.

    Args:
        df: DataFrame with a string-valued ``Ambigious`` column.
        seedVal: currently unused (seeding is disabled in this function), so
            the split is not made reproducible here.

    Returns:
        Tuple ``(X_train, X_val, X_test)`` of DataFrames.
    """
    ambiguous_rows = df[df['Ambigious'] == "1"]
    unambiguous_rows = df[df['Ambigious'] == "0"]

    # DataFrame.append was removed in pandas 2.0; pd.concat does the same job.
    df_new = pd.concat([ambiguous_rows, unambiguous_rows], ignore_index=True)

    y = df_new['Ambigious'].copy(deep=True)
    X = df_new.drop('Ambigious', axis=1).copy(deep=True)

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True)

    print(len(X_train))
    print(len(y_train))
    print(len(X_test))
    print(len(y_test))
    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, shuffle=True)

    X_train.insert(2, "Ambigious", y_train.values)
    X_val.insert(2, "Ambigious", y_val.values)
    X_test.insert(2, "Ambigious", y_test.values)

    for banner, split in (("****Train****", X_train), ("****Test****", X_test)):
        print(banner)
        print("Ambigious", split[split['Ambigious'] == "1"].shape[0])
        print("Not Ambigious", split[split['Ambigious'] == "0"].shape[0])

    X_train.to_csv('./dataset/batches_cleaned/stance/train_serp.tsv', sep='\t', index=False)
    X_val.to_csv('./dataset/batches_cleaned/stance/val_serp.tsv', sep='\t', index=False)
    X_test.to_csv('./dataset/batches_cleaned/stance/test_serp.tsv', sep='\t', index=False)

    return X_train, X_val, X_test

def merge_datasets(df, dfVal, dfTest):
    """Concatenate three splits, drop incomplete rows and separate the labels.

    Cells equal to ``""`` or ``" "`` are treated as missing; any row with a
    missing value is dropped. The inputs are not modified.

    Args:
        df: first DataFrame (e.g. the training split).
        dfVal: second DataFrame.
        dfTest: third DataFrame.

    Returns:
        Tuple ``(features, labels)`` where ``labels`` is the ``ideology``
        column and ``features`` is the merged frame without it.
    """
    from numpy import nan

    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    merged = pd.concat([df, dfVal, dfTest], ignore_index=True)

    merged = merged.replace(["", " "], nan)
    # Passing `how` together with an explicit `thresh` raises on pandas >= 2.
    merged = merged.dropna(axis=0, how='any')

    dfLabel = merged['ideology'].copy(deep=True)
    merged = merged.drop('ideology', axis=1).copy(deep=True)

    return merged, dfLabel

def preprocess_dataset_new_ideology_latest(df_new, testPer, seedVal):
    """Make a shuffled train/test split stratified on the ``ideology`` column.

    Args:
        df_new: DataFrame containing an ``ideology`` column.
        testPer: value forwarded as ``test_size`` to ``train_test_split``.
        seedVal: seed handed to ``create_determinism`` before splitting.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)``; the ``X`` parts are
        ``df_new`` without the ``ideology`` column.
    """
    from sklearn.model_selection import train_test_split

    create_determinism(seedVal)

    labels = df_new['ideology'].copy(deep=True)
    features = df_new.drop('ideology', axis=1).copy(deep=True)

    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=testPer, shuffle=True, stratify=labels
    )

    return train_features, train_labels, test_features, test_labels

def preprocess_dataset_new(df, dfLabels, testPer, seedVal):
    """Split the dataset into train/test separately for every query.

    ``dfLabels`` is inserted into ``df`` (in place) as column ``ideology`` at
    position 2, the rows are grouped by ``Q`` and each group is split with
    ``train_test_split``. A group with a single row goes entirely to the
    training set. The split is stratified on the label only when both
    classes (0 and 1) have at least two rows and the expected test size
    exceeds one row.

    Args:
        df: DataFrame with seven columns including ``Q``; it is mutated by
            the label insertion.
        dfLabels: Series of labels aligned with ``df`` by position.
        testPer: value forwarded as ``test_size`` to ``train_test_split``.
        seedVal: seed handed to ``create_determinism`` before splitting.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test)`` of DataFrames; the ``y``
        frames have the single column ``ideology``.
    """
    from sklearn.model_selection import train_test_split

    create_determinism(seedVal)

    df.insert(2, "ideology", dfLabels.values)
    df = df.sort_values(by='Q')

    # NOTE(review): each group is relabelled positionally with this
    # hard-coded list. Because the labels were inserted at position 2, they
    # end up under the name 'stance' here, and the column at position 3 is
    # what gets treated as 'ideology' -- confirm this is intended.
    group_columns = ['qID', 'q_ideology', 'stance', 'ideology', 'docCont', 'topic', 'Q', 'title']
    feature_columns = [name for name in group_columns if name != 'ideology']

    train_X_parts, train_y_parts = [], []
    test_X_parts, test_y_parts = [], []

    # groupby visits every query, including the last one (the hand-written
    # grouping loop used before never stored its final group), and copes with
    # an empty frame.
    for _, group in df.groupby('Q', sort=False, dropna=False):
        query_df = pd.DataFrame(list(group.to_numpy()), columns=group_columns)

        y = query_df['ideology'].copy(deep=True)
        X = query_df.drop('ideology', axis=1).copy(deep=True)

        if len(X.index) <= 1:
            # A single row cannot be split; keep it for training.
            train_X_parts.append(X)
            train_y_parts.append(y.to_frame())
            continue

        con_count = int((y == 0).sum())
        lib_count = int((y == 1).sum())
        can_stratify = (
            con_count >= 2
            and lib_count >= 2
            and (con_count + lib_count) * testPer > 1
        )

        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=testPer, shuffle=True,
            stratify=y if can_stratify else None,
        )
        train_X_parts.append(X_tr)
        train_y_parts.append(y_tr.to_frame())
        test_X_parts.append(X_te)
        test_y_parts.append(y_te.to_frame())

    def _combine(parts, columns):
        # DataFrame.append was removed in pandas 2.0; concatenate once instead.
        if not parts:
            return pd.DataFrame(columns=columns)
        return pd.concat(parts, ignore_index=True)

    X_train = _combine(train_X_parts, feature_columns)
    y_train = _combine(train_y_parts, ['ideology'])
    X_test = _combine(test_X_parts, feature_columns)
    y_test = _combine(test_y_parts, ['ideology'])

    return X_train, y_train, X_test, y_test

def create_train_val_test_split(trainpath, valpath, testpath, testPer, valPer, seedVal):
    """Re-split the ideology samples into new train/val/test TSV files.

    The three sample files under ``./dataset/ideology/`` are loaded, merged
    and split twice with a test fraction of 0.2: first to carve out the test
    set, then to carve the validation set out of the remainder. The results
    are written to ``train_new.tsv``, ``val_new.tsv`` and ``test_new.tsv`` in
    the working directory.

    NOTE(review): ``trainpath``, ``valpath``, ``testpath``, ``testPer`` and
    ``valPer`` are accepted but not used; paths and fractions are hard-coded.

    Args:
        seedVal: seed forwarded to the splitting helper.
    """
    source_paths = (
        "./dataset/ideology/train_samples.tsv",
        "./dataset/ideology/val_samples.tsv",
        "./dataset/ideology/test_samples.tsv",
    )
    train_part, val_part, test_part = [load_dataset(source) for source in source_paths]

    pooled, pooled_labels = merge_datasets(train_part, val_part, test_part)
    pooled.insert(2, "ideology", pooled_labels.values)

    remainder, remainder_labels, test_split, test_labels = preprocess_dataset_new_ideology_latest(pooled, 0.2, seedVal)
    # The splitter needs the label column back in the frame for the second pass.
    remainder.insert(2, "ideology", remainder_labels.values)

    train_split, train_labels, val_split, val_labels = preprocess_dataset_new_ideology_latest(remainder, 0.2, seedVal)

    labelled_splits = (
        (train_split, train_labels),
        (val_split, val_labels),
        (test_split, test_labels),
    )
    for split, labels in labelled_splits:
        split.insert(2, "ideology", labels.values)

    outputs = (
        (train_split, 'train_new.tsv'),
        (val_split, 'val_new.tsv'),
        (test_split, 'test_new.tsv'),
    )
    for split, filename in outputs:
        split.to_csv(filename, sep='\t', index=False)

def preprocess_ideologyOld(stance_labels, ideology_labels):
    """Encode stance/ideology label pairs as one-hot tensors plus scalar flags.

    For every position the (stance, ideology) pair is mapped to:
      * a one-hot stance ([1,0] = pro, [0,1] = against),
      * a one-hot ideology ([1,0] = lib, [0,1] = con),
      * the stance as a 0/1 flag and the ideology as a 0/1 flag.
    Any pair other than (1,0), (1,1) or (0,0) is encoded as against/lib.

    Args:
        stance_labels: iterable of stance labels.
        ideology_labels: indexable of ideology labels, same length.

    Returns:
        Tuple ``(t_stance, t_ideology, t_mmd_symbol, t_mmd_symbol_)``: two
        int32 tensors followed by two float32 tensors.
    """
    encoded_rows = []

    for position, stance in enumerate(stance_labels):
        ideology = ideology_labels[position]
        if stance == 1 and ideology == 0:  # pro-con
            row = ([1, 0], [0, 1], 1, 0)
        elif stance == 1 and ideology == 1:  # pro-lib
            row = ([1, 0], [1, 0], 1, 1)
        elif stance == 0 and ideology == 0:  # agst-con
            row = ([0, 1], [0, 1], 0, 0)
        else:  # agst-lib
            row = ([0, 1], [1, 0], 0, 1)
        encoded_rows.append(row)

    stance_one_hot = [row[0] for row in encoded_rows]
    ideology_one_hot = [row[1] for row in encoded_rows]
    stance_flags = [row[2] for row in encoded_rows]
    ideology_flags = [row[3] for row in encoded_rows]

    return (
        torch.as_tensor(stance_one_hot, dtype=torch.int32),
        torch.as_tensor(ideology_one_hot, dtype=torch.int32),
        torch.as_tensor(stance_flags, dtype=torch.float32),
        torch.as_tensor(ideology_flags, dtype=torch.float32),
    )

def preprocess_ideology_ambigious(ambigious_labels):
    """Turn string ambiguity labels into an int32 column tensor.

    A label equal to the string ``"0"`` becomes ``[0]``; every other value
    becomes ``[1]``.

    Args:
        ambigious_labels: indexable collection of labels.

    Returns:
        int32 tensor with one single-element row per label.
    """
    encoded = [
        [0] if ambigious_labels[position] == "0" else [1]
        for position, _ in enumerate(ambigious_labels)
    ]

    return torch.as_tensor(encoded, dtype=torch.int32)

def preprocess_ideology_new(stance_labels, ideology_labels):
    """Turn ideology labels into an int32 column tensor.

    A label equal to 0 (con) becomes ``[0]``; every other value (lib)
    becomes ``[1]``. ``stance_labels`` only determines how many positions
    are read.

    Args:
        stance_labels: iterable whose length sets the number of rows.
        ideology_labels: indexable collection of ideology labels.

    Returns:
        int32 tensor with one single-element row per position.
    """
    encoded = [
        [0] if ideology_labels[position] == 0 else [1]
        for position, _ in enumerate(stance_labels)
    ]

    return torch.as_tensor(encoded, dtype=torch.int32)

def preprocess_ideology(stance_labels, ideology_labels):
    """Encode each (stance, ideology) pair as a two-element int32 row.

    The row is ``[stance, ideology]`` for the pairs (1,0), (1,1) and (0,0);
    every other combination is encoded as ``[0, 1]`` (against/lib).

    Args:
        stance_labels: iterable of stance labels.
        ideology_labels: indexable of ideology labels, same length.

    Returns:
        int32 tensor with one ``[stance, ideology]`` row per position.
    """
    pairs = []

    for position, stance in enumerate(stance_labels):
        ideology = ideology_labels[position]
        if stance == 1 and ideology == 0:  # pro-con
            pair = [1, 0]
        elif stance == 1 and ideology == 1:  # pro-lib
            pair = [1, 1]
        elif stance == 0 and ideology == 0:  # agst-con
            pair = [0, 0]
        else:  # agst-lib
            pair = [0, 1]
        pairs.append(pair)

    return torch.as_tensor(pairs, dtype=torch.int32)

def concanListStringsLonger(list1, list2):
    """Join the two lists element-wise with the ``" GIZEM "`` marker between them.

    A length mismatch is only reported with a printed message; the result
    always has ``len(list1)`` entries.

    Args:
        list1: indexable of strings placed before the marker.
        list2: indexable of strings placed after the marker.

    Returns:
        List of the concatenated strings.
    """
    count = len(list1)
    if count != len(list2):
        print("Length - error")
    return [list1[position] + " GIZEM " + list2[position] for position in range(count)]

def concanListStrings(list1, list2):
    """Join the two lists element-wise with a single space between them.

    A length mismatch is only reported with a printed message; the result
    always has ``len(list1)`` entries.

    Args:
        list1: indexable of strings placed first.
        list2: indexable of values appended after the space; each one is
            converted with ``str()``.

    Returns:
        List of the concatenated strings.
    """
    myLen1 = len(list1)
    if myLen1 != len(list2):
        print("Length - error")
    # str() mirrors concanListStrings_sep: this helper is also called with
    # numeric stance labels, which plain `+` cannot concatenate to a string.
    return [list1[idx] + " " + str(list2[idx]) for idx in range(myLen1)]

def concanListStrings_sep(list1, list2):
    """Join the two lists element-wise with ``" [SEP] "`` between them.

    Entries of ``list2`` are converted with ``str()`` first. A length
    mismatch is only reported with a printed message; the result always has
    ``len(list1)`` entries.

    Args:
        list1: indexable of strings placed before the separator.
        list2: indexable of values placed after the separator.

    Returns:
        List of the concatenated strings.
    """
    count = len(list1)
    if count != len(list2):
        print("Length - error")

    return [list1[position] + " [SEP] " + str(list2[position]) for position in range(count)]

### Generate the datasets with the different fields.
def generate_datasets_ambigious(df, tokenizer):
    """Build the text inputs and labels for the ambiguity task.

    Args:
        df: DataFrame with ``Q``, ``title``, ``docCont`` and ``Ambigious``
            columns.
        tokenizer: not used by this function.

    Returns:
        Tuple ``(query_title, query_title_content, labels)``: the texts
        "query title", the same texts followed by the ``GIZEM`` marker and
        the document content, and the ``Ambigious`` values.
    """
    queries = df.Q.values
    titles = df.title.values
    contents = df.docCont.values
    ambiguity_labels = df.Ambigious.values

    query_title = concanListStrings(queries, titles)
    query_title_content = concanListStringsLonger(query_title, contents)

    return query_title, query_title_content, ambiguity_labels


### Generate the datasets with the different fields.
def generate_datasets_ideology(df, tokenizer):
    """Build the text input variants and labels for the ideology task.

    Args:
        df: DataFrame with ``Q``, ``q_ideology``, ``title``, ``docCont``,
            ``stance`` and ``ideology`` columns.
        tokenizer: not used by this function.

    Returns:
        Tuple of four text lists followed by two label arrays:
        "query title", "query title GIZEM content", "query title stance",
        "query title stance GIZEM content", the ``stance`` values and the
        ``ideology`` values.
    """
    queries = df.Q.values
    _query_ideologies = df.q_ideology.values  # read but not used further
    titles = df.title.values
    contents = df.docCont.values

    stances = df.stance.values
    labels = df.ideology.values

    query_title = concanListStrings(queries, titles)
    query_title_stance = concanListStrings(query_title, stances)
    query_title_content = concanListStringsLonger(query_title, contents)
    query_title_stance_content = concanListStringsLonger(query_title_stance, contents)

    return (
        query_title,
        query_title_content,
        query_title_stance,
        query_title_stance_content,
        stances,
        labels,
    )

def preprocessing_for_bert(tokenizer, docs, max_len, doc_stride):
    """Encode every text with ``tokenizer.encode_plus`` and stack the results.

    Each text is encoded with special tokens added, ``max_length=max_len``,
    ``pad_to_max_length=True`` and PyTorch tensors as output.

    @param    tokenizer: object exposing ``encode_plus``.
    @param    docs: iterable of texts to encode.
    @param    max_len (int): length passed as ``max_length``.
    @param    doc_stride: not used by this function.
    @return   input_ids (torch.Tensor): per-text ``input_ids`` concatenated
                  along dimension 0.
    @return   attention_masks (torch.Tensor): per-text ``attention_mask``
                  concatenated along dimension 0.
    """
    encodings = [
        tokenizer.encode_plus(
            text,
            add_special_tokens=True,       # add `[CLS]` and `[SEP]`
            max_length=max_len,
            pad_to_max_length=True,
            return_tensors='pt',
            return_attention_mask=True,
        )
        for text in docs
    ]

    # Stack the per-text tensors into one tensor each.
    input_ids = torch.cat([encoding['input_ids'] for encoding in encodings], dim=0)
    attention_masks = torch.cat([encoding['attention_mask'] for encoding in encodings], dim=0)

    return input_ids, attention_masks

def transform_sequences_longer_ideology(tokenizer, docs, stanceLabels, ideologylabels, max_len, doc_stride):
    """Encode documents for the model, splitting long ones into several windows.

    Each document is split on spaces at the first ``GIZEM`` word: the words
    before it form a prefix that is repeated in every window, the words after
    it form the body that is windowed. Every encoded window gets a copy of
    the document's stance and ideology labels, so the outputs can have more
    rows than ``docs``.

    Args:
        tokenizer: object exposing ``tokenize`` and ``encode_plus``.
        docs: indexable of document strings.
        stanceLabels: indexable of stance labels aligned with ``docs``.
        ideologylabels: indexable of ideology labels aligned with ``docs``.
        max_len: length passed as ``max_length`` to ``encode_plus``; also
            bounds the window size.
        doc_stride: 0 means "only use the beginning of the body"; otherwise
            ``doc_stride - 3`` body tokens are dropped between windows.

    Returns:
        Tuple ``(all_input_ids, all_input_mask, stance_labels,
        ideology_labels)`` of tensors, one row/entry per encoded window.
    """

    special_tokens_count = 2 #[CLS] and [SEP]
    # Accumulators, one entry per encoded window.
    input_ids = []
    attention_masks = []
    stance_labels_Transformed = []
    ideology_labels_Transformed = []

    # A stride of 0 disables windowing: only one truncated window is built.
    only_get_partial_text = False
    if doc_stride == 0:
        only_get_partial_text = True
        
    # Number of body tokens removed from the front after each window.
    checked_doc_stride_thresh = doc_stride - special_tokens_count - 1
        
    allDocs_len = len(docs)
    for doc_id in range(0, allDocs_len):
        currDoc = docs[doc_id]
        currStanceLabel = stanceLabels[doc_id]
        currIdeologyLabel = ideologylabels[doc_id]
        
        # Position of the 'GIZEM' word separating prefix and body.
        # NOTE(review): without a marker my_idx stays 0, so the prefix is
        # empty and the slice [my_idx+1:] below skips the first word.
        my_idx = 0
        if "GIZEM" in currDoc:
            doc_splitted_tokens = currDoc.split(" ")
            my_idx = doc_splitted_tokens.index('GIZEM')
        else:
            doc_splitted_tokens = currDoc.split(" ")
        
        # Prefix (text before the marker) and body (text after the marker).
        first_part_tokens = tokenizer.tokenize(' '.join(doc_splitted_tokens[0:my_idx]))
        myTokens = tokenizer.tokenize(' '.join(doc_splitted_tokens[my_idx+1:]))
        mytokens_maxlen = []

        first_part_len = len(first_part_tokens)
        cur_len = len(myTokens)
        # Body tokens that fit in one window next to the prefix.
        taken_len = max_len - first_part_len - special_tokens_count - 1
        
        if only_get_partial_text:
            mytokens_maxlen.append(first_part_tokens + myTokens[0:taken_len])
        else:
            checked_thresh = max_len - first_part_len - special_tokens_count
            if cur_len > checked_thresh:
                # Body is too long: emit windows and advance by the stride.
                # NOTE(review): if checked_doc_stride_thresh is 0
                # (doc_stride == 3) the `del` removes nothing and this loop
                # does not terminate -- confirm callers never pass that.
                while cur_len > checked_thresh:
                    partialTokens = first_part_tokens + myTokens[0:taken_len]
                    mytokens_maxlen.append(partialTokens)
                    del myTokens[0:checked_doc_stride_thresh]
                    cur_len = len(myTokens)
                # Whatever is left forms the final window.
                if cur_len > 0:
                    mytokens_maxlen.append(first_part_tokens + myTokens)
            else:
                mytokens_maxlen.append(first_part_tokens + myTokens)

        if len(mytokens_maxlen) == 1:
            # Single window: the raw document string is encoded, not the
            # token list built above.
            # NOTE(review): the raw string still contains the 'GIZEM' marker
            # word -- confirm that is intended.
            encoded_dict = tokenizer.encode_plus(
                        currDoc,                      # Sentence to encode.
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        max_length = max_len,           # Pad & truncate all sentences.
                        pad_to_max_length = True,
                        return_attention_mask = True,   # Construct attn. masks.
                        return_tensors = 'pt',     # Return pytorch tensors.
                   )
    
            # Add the encoded sentence to the list.
            input_ids.append(encoded_dict['input_ids'])
    
            # And its attention mask (simply differentiates padding from non-padding).
            attention_masks.append(encoded_dict['attention_mask'])

            stance_labels_Transformed.append(currStanceLabel)
            ideology_labels_Transformed.append(currIdeologyLabel)
        else:
            # Several windows: each token list is encoded separately and
            # receives a copy of the document's labels.
            for maxTokenList in mytokens_maxlen:
                # Diagnostic only: report windows longer than 510 tokens.
                if len(maxTokenList) > 510:
                    print(len(maxTokenList))
                # The token list (not a string) is passed to encode_plus here.
                encoded_dict = tokenizer.encode_plus(
                    maxTokenList,                      # Sentence to encode.
                    add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                    max_length = max_len,           # Pad & truncate all sentences.
                    pad_to_max_length = True,
                    return_attention_mask = True,   # Construct attn. masks.
                    return_tensors = 'pt',     # Return pytorch tensors.
                    )
 
                input_ids.append(encoded_dict['input_ids'])
                attention_masks.append(encoded_dict['attention_mask'])
                stance_labels_Transformed.append(currStanceLabel)
                ideology_labels_Transformed.append(currIdeologyLabel)


    # Stack the per-window tensors and convert the label lists to tensors.
    all_input_ids = torch.cat(input_ids, dim=0)
    all_input_mask = torch.cat(attention_masks, dim=0)
    stance_labels = torch.tensor(stance_labels_Transformed)
    ideology_labels = torch.tensor(ideology_labels_Transformed)
    
    print(all_input_ids.shape)

    return all_input_ids, all_input_mask, stance_labels, ideology_labels

import torch
from transformers import BertModel, RobertaModel
class IdeologyDetectionClass(torch.nn.Module):
    """BERT encoder with a one-hidden-layer head that outputs class probabilities.

    The first-position vector of the encoder's first output goes through
    Linear -> BatchNorm1d -> ReLU -> Dropout, then a Linear layer with two
    outputs and a softmax over dimension 1.
    """

    def __init__(self, modelUsed):
        """Build the network.

        Args:
            modelUsed: name or path passed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        # Hard-coded sizes: the head expects 768-dimensional encoder vectors.
        input_size = 768
        hidden_size = 768
        dropout_prob = 0.5
        classes_size = 2

        self.input_pl = BertModel.from_pretrained(modelUsed) #input
        self.l1 = torch.nn.Linear(input_size, hidden_size)
        self.bn1_hidden = torch.nn.BatchNorm1d(hidden_size, momentum=0.05)
        # forward() applies self.relu; it was never created before, so the
        # first forward pass raised AttributeError. ReLU holds no parameters,
        # so existing state dicts still load.
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout_prob)

        self.stance = torch.nn.Linear(hidden_size, classes_size)
        self.output_prob = torch.nn.Softmax(dim = 1)

    def forward(self, input_ids, attention_mask):
        """Return the softmax class probabilities for a batch.

        Args:
            input_ids: token id tensor forwarded to the encoder.
            attention_mask: attention mask tensor forwarded to the encoder.

        Returns:
            Tensor of probabilities, one row per batch element and one
            column per class.
        """
        encoder_output = self.input_pl(input_ids = input_ids, attention_mask = attention_mask)
        # First output of the encoder, vector at sequence position 0.
        last_hidden_state_cls = encoder_output[0][:, 0, :]

        # Hidden layer.
        hidden_state = self.l1(last_hidden_state_cls)
        hidden_state = self.bn1_hidden(hidden_state)
        hidden_state = self.relu(hidden_state)
        hidden_layer = self.dropout(hidden_state)

        stance_state = self.stance(hidden_layer) #batch size x classes_size
        P_stance = self.output_prob(stance_state)

        return P_stance

import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.val_acc_max = -1
        self.delta = delta

    def __call__(self, val_loss, val_acc, model_save_state, model_save_path):
        """Record one validation result.

        On the first call and on every improvement the checkpoint is saved
        and the patience counter is reset; otherwise the counter is
        incremented and ``early_stop`` is set once it reaches ``patience``.

        Args:
            val_loss: validation loss of the current epoch (lower is better).
            val_acc: validation accuracy, tracked for reporting only.
            model_save_state: object handed to ``torch.save``.
            model_save_path: destination path for the checkpoint.
        """
        score = -val_loss

        if self.best_score is not None and score < self.best_score + self.delta:
            # No improvement: one step closer to stopping.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First result or an improvement. The counter is reset and NOT
        # incremented again (it used to be bumped right after the reset, so an
        # improvement could itself trigger early stopping).
        self.best_score = score
        self.save_checkpoint(val_loss, val_acc, model_save_state, model_save_path)
        self.val_acc_max = val_acc
        self.counter = 0

    def save_checkpoint(self, val_loss, val_acc, model_save_state, model_save_path):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
            print(f'Validation acc: ({self.val_acc_max:.6f} --> {val_acc:.6f}).  Saving model ...')

        torch.save(model_save_state, model_save_path)

        self.val_loss_min = val_loss

#from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
def prepare_for_training_ambigious(input_idsTrain, attention_masksTrain, ideology_labels_Train, input_idsVal, attention_masksVal, ideology_labels_Val, modelUsed, batch_size=16, epochs = 50, num_warmup_steps=0, learning_rate=5e-5):
    """Build the model, datasets, optimizer and LR scheduler for the ambiguity task.

    Args:
        input_idsTrain, attention_masksTrain: training input tensors.
        ideology_labels_Train: training labels, encoded with
            ``preprocess_ideology_ambigious``.
        input_idsVal, attention_masksVal: validation input tensors.
        ideology_labels_Val: validation labels, encoded the same way.
        modelUsed: pretrained model name/path for ``IdeologyDetectionClass``.
        batch_size: batch size used to count the optimisation steps.
        epochs: number of epochs used to count the optimisation steps.
        num_warmup_steps: warm-up steps of the linear schedule.
        learning_rate: optimizer learning rate.

    Returns:
        Tuple ``(model, datasetTrain, datasetVal, optimizer, scheduler)``;
        the scheduler is the linear warm-up schedule and the model has been
        moved to the GPU with ``model.cuda()``.
    """
    # Import only what is used; the previously imported extras (model
    # classes, configs) were dead weight that could break this call.
    from torch.utils.data import DataLoader, RandomSampler
    from transformers import AdamW, get_linear_schedule_with_warmup

    t_train_stance = preprocess_ideology_ambigious(ideology_labels_Train)
    datasetTrain = TensorDataset(input_idsTrain, attention_masksTrain, t_train_stance)

    t_val_stance = preprocess_ideology_ambigious(ideology_labels_Val)
    datasetVal = TensorDataset(input_idsVal, attention_masksVal, t_val_stance)

    model = IdeologyDetectionClass(modelUsed)

    # Tell pytorch to run this model on the GPU.
    model.cuda()

    # Note: AdamW is the class from the huggingface library (as opposed to pytorch).
    optimizer = AdamW(model.parameters(),
                  lr = learning_rate,
                  betas=(0.9, 0.999),
                  eps=1e-08,
                  weight_decay=1e-3,
                  correct_bias=True
               )

    # This loader is only used to count the batches per epoch.
    train_dataloader = DataLoader(
            datasetTrain,
            sampler = RandomSampler(datasetTrain),
            batch_size = batch_size,
            num_workers=8
        )

    # Total number of training steps is [number of batches] x [number of epochs].
    total_steps = len(train_dataloader) * epochs

    # The cosine-with-restarts scheduler and the BCELoss that used to be built
    # here were never returned or used, so they are no longer created.
    schedulerOld = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = num_warmup_steps,
                                            num_training_steps = total_steps)

    return model, datasetTrain, datasetVal, optimizer, schedulerOld

#from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
def prepare_for_trainingOld(input_idsTrain, attention_masksTrain, stance_labels_Train, ideology_labels_Train, input_idsVal, attention_masksVal, stance_labels_Val, ideology_labels_Val, modelUsed, batch_size=16, epochs = 50, num_warmup_steps=0, learning_rate=5e-5):
    """Build model, datasets, optimizer, scheduler and loss for the stance+ideology setup.

    Args:
        input_idsTrain, attention_masksTrain: training input tensors.
        stance_labels_Train, ideology_labels_Train: training labels, encoded
            with ``preprocess_ideologyOld``.
        input_idsVal, attention_masksVal: validation input tensors.
        stance_labels_Val, ideology_labels_Val: validation labels, encoded
            the same way.
        modelUsed: pretrained model name/path for ``IdeologyDetectionClass``.
        batch_size: batch size used to count the optimisation steps.
        epochs: number of epochs used to count the optimisation steps.
        num_warmup_steps: warm-up steps of the linear schedule.
        learning_rate: optimizer learning rate.

    Returns:
        Tuple ``(model, datasetTrain, datasetVal, optimizer, scheduler,
        loss_fct)``; each dataset holds input ids, attention masks and the
        four label tensors; ``loss_fct`` is ``torch.nn.BCELoss()``.
    """
    from torch.utils.data import DataLoader, RandomSampler
    from transformers import AdamW, get_linear_schedule_with_warmup

    # preprocess_ideologyOld is the encoder that returns the four tensors
    # unpacked here; preprocess_ideology returns a single tensor, so calling
    # it (as was done before) fails on unpacking.
    t_train_stance, t_train_ideology, t_train_mmd_symbol, t_train_mmd_symbol_ = preprocess_ideologyOld(stance_labels_Train, ideology_labels_Train)

    datasetTrain = TensorDataset(input_idsTrain, attention_masksTrain, t_train_stance, t_train_ideology, t_train_mmd_symbol, t_train_mmd_symbol_)

    t_val_stance, t_val_ideology, t_val_mmd_symbol, t_val_mmd_symbol_ = preprocess_ideologyOld(stance_labels_Val, ideology_labels_Val)

    datasetVal = TensorDataset(input_idsVal, attention_masksVal, t_val_stance, t_val_ideology, t_val_mmd_symbol, t_val_mmd_symbol_)

    model = IdeologyDetectionClass(modelUsed)

    # Tell pytorch to run this model on the GPU.
    model.cuda()

    # Note: AdamW is the class from the huggingface library (as opposed to pytorch).
    optimizer = AdamW(model.parameters(),
                  lr = learning_rate,
                  betas=(0.9, 0.999),
                  eps=1e-08,
                  weight_decay=1e-3,
                  correct_bias=True
               )

    # This loader is only used to count the batches per epoch.
    train_dataloader = DataLoader(
            datasetTrain,
            sampler = RandomSampler(datasetTrain),
            batch_size = batch_size,
            num_workers=8
        )

    # Total number of training steps is [number of batches] x [number of epochs].
    total_steps = len(train_dataloader) * epochs

    # The unused cosine-with-restarts scheduler is no longer created.
    schedulerOld = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = num_warmup_steps,
                                            num_training_steps = total_steps)

    loss_fct = torch.nn.BCELoss()
    return model, datasetTrain, datasetVal, optimizer, schedulerOld, loss_fct

def return_batches_datasets(datasetTrain, datasetVal, batch_size = 16):
    """Wrap the training and validation datasets in DataLoaders.

    Args:
        datasetTrain: Dataset served in random order.
        datasetVal: Dataset served in its stored order.
        batch_size: Number of samples per batch for both loaders.

    Returns:
        A ``(train_dataloader, validation_dataloader)`` tuple.
    """
    from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

    # Both loaders share the batch size and load data without worker processes.
    loader_options = {"batch_size": batch_size, "num_workers": 0}

    # Training batches are drawn randomly.
    shuffled_loader = DataLoader(
        datasetTrain,
        sampler=RandomSampler(datasetTrain),
        **loader_options
    )

    # Order is irrelevant for validation, so the samples are read sequentially.
    ordered_loader = DataLoader(
        datasetVal,
        sampler=SequentialSampler(datasetVal),
        **loader_options
    )

    return shuffled_loader, ordered_loader

def optimizer_to(optim, device):
    """Move the tensors stored in ``optim.state`` onto ``device``, in place.

    Each value of ``optim.state`` is either a tensor or a dict; for dicts the
    tensors one level down are moved. Anything else is left untouched.

    Args:
        optim: Object exposing a ``state`` mapping (e.g. a torch optimizer).
        device: Target device accepted by ``Tensor.to``.
    """
    def _relocate(tensor):
        # Swap the storage and, when a gradient is attached, its storage too.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                _relocate(value)


from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, RandomSampler
#from tensorboardX import SummaryWriter
from sklearn.metrics import confusion_matrix
#import EarlyStopping
def train_stance_ideology_ambigious(train_nums, val_nums, train_nums_ideology, val_nums_ideology, model_save_path, 
                                    model, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train ``model`` on the stance labels and validate it after every epoch.

    The model is fed micro-batches of at most 16 samples; a larger
    ``batch_size`` is reached through gradient accumulation. After each epoch
    the validation loss and accuracies are handed to an ``EarlyStopping``
    instance, which decides when training stops.

    Args:
        train_nums: Unused.
        val_nums: Indexable holding four per-class validation counts, read as
            [0] agree, [1] disagree, [2] discuss, [3] unrelated. They are the
            denominators of the reported accuracies, so a zero count (other
            than index 2, which is offset by 0.01) would divide by zero.
        train_nums_ideology: Unused.
        val_nums_ideology: Unused; fixed placeholder counts are used instead.
        model_save_path: Checkpoint path. Read when ``continue_train`` is set
            and forwarded to the early-stopping callback.
        model: Called with ``input_ids``, ``attention_mask``, ``mmd_pl`` and
            ``mmd_pl_``; its output is scored with ``BCEWithLogitsLoss``.
        datasetTrain: Training dataset. Batches are indexed as [0] input ids,
            [1] attention mask, [3] stance labels, [4] and [5] the ``mmd_pl``
            / ``mmd_pl_`` inputs.
        datasetVal: Validation dataset with the same batch layout.
        epochs: Number of epochs to run (counted from the checkpoint epoch
            when resuming).
        batch_size: Effective batch size. It is doubled whenever the internal
            epoch counter reaches a multiple of 500.
        optimizer: Optimizer stepped once per accumulated batch.
        scheduler: Learning-rate scheduler stepped together with the optimizer.
        patience: Forwarded to ``EarlyStopping``.
        verbose: Forwarded to ``EarlyStopping``.
        delta: Forwarded to ``EarlyStopping``.
        seedVal: Unused.
        continue_train: When true, model/optimizer state and the starting
            epoch are restored from ``model_save_path`` first.

    Returns:
        ``(training_stats, last_epoch, min_val_loss, max_val_acc)``: the list
        of per-epoch statistics dicts, the 1-based number of the last
        completed epoch, ``es.val_loss_min`` and ``es.val_acc_max_stance``.
    """
    # Per-class validation counts: denominators of the accuracies reported below.
    pro_val_num = val_nums[0]
    agst_val_num = val_nums[1]
    # The small offset keeps this denominator non-zero when the class is absent.
    neut_val_num = val_nums[2] + 0.01
    notrel_val_num = val_nums[3]
    stance_all_num = pro_val_num + agst_val_num + neut_val_num + notrel_val_num

    # Placeholder ideology counts; val_nums_ideology is not wired in.
    con_val_num = 0.1
    lib_val_num = 0.1
    na_val_num = 0.1
    ideology_all_num = con_val_num + lib_val_num + na_val_num

    loss_fct_relatedness = torch.nn.BCEWithLogitsLoss()

    # Largest micro-batch pushed through the model at once; bigger effective
    # batches are built by accumulating gradients over several micro-batches.
    batch_size_max_once = min(batch_size, 16)
    # Must be a whole number: a fractional value (e.g. 24 / 16 = 1.5) makes the
    # modulo test below fire at the wrong steps and mis-scales the loss.
    accumulation_steps = int(batch_size / batch_size_max_once)

    es = EarlyStopping(patience,verbose, delta)
    # Nothing is logged through the writer yet; it is created once and closed
    # at the end so its file handle is not leaked.
    writer = SummaryWriter()

    # Loss, accuracy and timing figures collected for every epoch.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)
    num_train_batches = len(train_dataloader)

    epoch_start = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        # Multi-GPU: each batch is split along dim 0 across the devices.
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    print(device)

    if continue_train:
        # map_location lets a checkpoint written on a GPU be resumed on CPU.
        checkpoint = torch.load(model_save_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch']

    torch.cuda.empty_cache()
    model.to(device)
    optimizer_to(optimizer,device)

    # Defined up front so the return value exists even when no epoch runs.
    last_epoch = epoch_start
    batch_epoch_count = 1
    for epoch_i in range(epoch_start, epoch_start + epochs):

        print("---------Epoch----------" + str(epoch_i))

        # ========================================
        #               Training
        # ========================================
        t0 = time.time()
        total_train_loss = 0

        # `train` only switches the mode (dropout / batchnorm behaviour).
        model.train()
        model.zero_grad()
        optimizer.zero_grad()

        # Every 500 epochs the effective batch size is doubled.
        if batch_epoch_count % 500 == 0:
            batch_size = batch_size*2
            accumulation_steps = int(batch_size/batch_size_max_once)
        batch_epoch_count = batch_epoch_count + 1

        print("Batch Size: " + str(batch_size))
        print(float(accumulation_steps))

        for step, batch in enumerate(train_dataloader):
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[3].to(device)
            b_mmd_symbol = batch[4].to(device)
            b_mmd_symbol_ = batch[5].to(device)

            P_stance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)
            stance_loss = loss_fct_relatedness(P_stance, b_labels.float())

            # Scale so the accumulated gradient matches one large batch.
            loss = stance_loss / accumulation_steps
            total_train_loss += loss.item()
            loss.backward()

            # Step once per accumulation group, and also on the final batch so
            # a trailing partial group is applied instead of being wiped by
            # the zero_grad at the start of the next epoch.
            is_last_batch = (step + 1) == num_train_batches
            if (step + 1) % accumulation_steps == 0 or is_last_batch:
                # Clip the gradient norm to 1.0 against exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                # Update the learning rate.
                scheduler.step()

                # Gradients accumulate by default, so clear them explicitly.
                model.zero_grad()
                optimizer.zero_grad()

        print("Learning rate: ", scheduler.get_last_lr())
        # Undo the per-step scaling to report the mean micro-batch loss.
        avg_train_loss = total_train_loss / len(train_dataloader) * accumulation_steps

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("  Average training loss: {0:.6f}".format(avg_train_loss))
        print("  Training epoch took: {:}".format(training_time))

        # ========================================
        #               Validation
        # ========================================
        t1 = time.time()

        # Evaluation mode: dropout layers behave differently than in training.
        model.eval()

        # Tracking variables
        total_true_eval_stance = 0
        total_true_eval_ideology = 0
        total_eval_loss = 0

        agree_val_true = 0
        disagree_val_true = 0
        discuss_val_true = 0
        unrelated_val_true = 0

        con_val_true = 0
        lib_val_true = 0
        na_val_true = 0

        # NOTE(review): never updated, so the "Total True" line always shows 0.
        total_true = 0

        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_labels = batch[3].to(device)
            b_mmd_symbol = batch[4].to(device)
            b_mmd_symbol_ = batch[5].to(device)

            # No compute graph is needed for the validation forward pass.
            with torch.no_grad():
                P_stance = model(input_ids = b_input_ids, attention_mask = b_input_mask, mmd_pl = b_mmd_symbol, mmd_pl_ = b_mmd_symbol_)

                stance_loss = loss_fct_relatedness(P_stance, b_labels.float())
                total_eval_loss += stance_loss.item()

                # Move logits and labels to CPU
                P_stance = P_stance.to('cpu')
                b_labels = b_labels.to('cpu')

                # acc_list as consumed here (presumably counts of correct
                # predictions -- confirm against the helper): [0] stance
                # overall, [1]-[4] agree / disagree / discuss / unrelated,
                # [5] ideology overall, [6]-[8] con / lib / na.
                acc_list = predict_classwise_stance_ideology_bert(P_stance, b_labels)
                total_true_eval_stance += acc_list[0]
                agree_val_true += acc_list[1]
                disagree_val_true += acc_list[2]
                discuss_val_true += acc_list[3]
                unrelated_val_true += acc_list[4]

                total_true_eval_ideology += acc_list[5]
                con_val_true += acc_list[6]
                lib_val_true += acc_list[7]
                na_val_true += acc_list[8]

        # Report the final accuracy for this validation run.
        avg_val_accuracy_stance = total_true_eval_stance / stance_all_num
        avg_val_accuracy_ideology = total_true_eval_ideology / ideology_all_num
        print("Avg Val Accuracy Stance: {0:.6f}".format(avg_val_accuracy_stance))
        print("Avg Val Accuracy Ideology: {0:.6f}".format(avg_val_accuracy_ideology))
        print("Total True")
        print(total_true)
        print("*************")
        avg_val_agree_accuracy = agree_val_true / pro_val_num
        print("Avg Val Agree Accuracy: {0:.6f}".format(avg_val_agree_accuracy))
        avg_val_disagree_accuracy = disagree_val_true / agst_val_num
        print("Avg Val Disagree Accuracy: {0:.6f}".format(avg_val_disagree_accuracy))
        avg_val_discuss_accuracy = discuss_val_true / neut_val_num
        print("Avg Val Discuss Accuracy: {0:.6f}".format(avg_val_discuss_accuracy))
        avg_val_unrelated_accuracy = unrelated_val_true / notrel_val_num
        print("Avg Val Unrelated Accuracy: {0:.6f}".format(avg_val_unrelated_accuracy))

        # 25% weight on the unrelated class, 75% on the mean of the other three.
        relative_score = 0.25*avg_val_unrelated_accuracy + 0.75*(avg_val_agree_accuracy + avg_val_disagree_accuracy + avg_val_discuss_accuracy)/3

        print("*****************")
        print("Relative score: {0:.6f}".format(relative_score))
        print("*****************")
        print("-------------")
        avg_val_con_accuracy = con_val_true / con_val_num
        print("Avg Val Con Accuracy: {0:.6f}".format(avg_val_con_accuracy))
        avg_lib_accuracy = lib_val_true / lib_val_num
        print("Avg Val Lib Accuracy: {0:.6f}".format(avg_lib_accuracy))
        avg_na_discuss_accuracy = na_val_true / na_val_num
        print("Avg Val NA Accuracy: {0:.6f}".format(avg_na_discuss_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        print("Total Validation loss", total_eval_loss)
        print("Len-validation loader", len(validation_dataloader))

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t1)

        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))
        print("  Validation took: {:}".format(validation_time))

        # Record all statistics from this epoch.
        training_stats.append(
            {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Stance Accur.': avg_val_accuracy_stance,
            'Valid. Ideology Accur.': avg_val_accuracy_ideology,
            'Training Time': training_time,
            'Validation Time': validation_time
            }
        )

        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }

        es(avg_val_loss, avg_val_accuracy_stance, avg_val_accuracy_ideology, model_save_state, model_save_path, model)
        last_epoch = epoch_i + 1
        if es.early_stop:
            break  # early stop criterion is met, we can stop now

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

    writer.close()

    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max_stance

    return training_stats, last_epoch, min_val_loss, max_val_acc

from torch.utils.tensorboard import SummaryWriter

#import EarlyStopping
def train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train ``model`` with a binary cross-entropy loss and early stopping.

    The model is fed micro-batches of at most 16 samples; a larger
    ``batch_size`` is reached through gradient accumulation. After each epoch
    the validation loss and accuracy are handed to an ``EarlyStopping``
    instance, which decides when training stops.

    NOTE(review): a second ``train_stance`` definition appears later in this
    file and replaces this one when the module is loaded.

    Args:
        model_save_path: Checkpoint path. Read when ``continue_train`` is set
            and forwarded to the early-stopping callback.
        model: Called with ``input_ids`` and ``attention_mask``. Its output is
            scored with ``BCELoss``, so it is expected to lie in [0, 1].
        tokenizer: Unused.
        datasetTrain: Training dataset. Batches are indexed as [0] input ids,
            [1] attention mask, [2] target labels.
        datasetVal: Validation dataset with the same batch layout.
        epochs: Number of epochs to run (counted from the checkpoint epoch
            when resuming).
        batch_size: Effective batch size. It is doubled whenever the internal
            epoch counter reaches a multiple of 200.
        optimizer: Optimizer stepped once per accumulated batch.
        scheduler: Learning-rate scheduler stepped together with the optimizer.
        patience: Forwarded to ``EarlyStopping``.
        verbose: Forwarded to ``EarlyStopping``.
        delta: Forwarded to ``EarlyStopping``.
        seedVal: Seed handed to ``create_determinism``.
        continue_train: When true, model/optimizer state and the starting
            epoch are restored from ``model_save_path`` first.

    Returns:
        ``(training_stats, last_epoch, min_val_loss, max_val_acc)``: the list
        of per-epoch statistics dicts, the 1-based number of the last
        completed epoch, ``es.val_loss_min`` and ``es.val_acc_max``.
    """
    loss_fct = torch.nn.BCELoss()
    create_determinism(seedVal)

    # Largest micro-batch pushed through the model at once; bigger effective
    # batches are built by accumulating gradients over several micro-batches.
    batch_size_max_once = min(batch_size, 16)
    # Must be a whole number: a fractional value (e.g. 24 / 16 = 1.5) makes the
    # modulo test below fire at the wrong steps and mis-scales the loss.
    accumulation_steps = int(batch_size / batch_size_max_once)

    es = EarlyStopping(patience,verbose, delta)
    # Nothing is logged through the writer yet; it is closed at the end so its
    # file handle is not leaked.
    writer = SummaryWriter()

    # Loss, accuracy and timing figures collected for every epoch.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)
    num_train_batches = len(train_dataloader)

    epoch_start = 0
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        # Multi-GPU: each batch is split along dim 0 across the devices.
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    print(device)

    # The caller's flag is honoured (it used to be forced to False) and the
    # checkpoint is read from the path the early-stopping callback receives,
    # not from a hard-coded directory.
    if continue_train:
        # map_location lets a checkpoint written on a GPU be resumed on CPU.
        checkpoint = torch.load(model_save_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch']

    torch.cuda.empty_cache()
    model.to(device)
    optimizer_to(optimizer,device)

    # Defined up front so the return value exists even when no epoch runs.
    last_epoch = epoch_start
    batch_epoch_count = 1
    # Epoch numbering continues from the checkpoint when resuming.
    for epoch_i in range(epoch_start, epoch_start + epochs):
        print("---------Epoch----------" + str(epoch_i))

        # ========================================
        #               Training
        # ========================================
        t0 = time.time()
        total_train_loss = 0

        # `train` only switches the mode (dropout / batchnorm behaviour).
        model.train()
        model.zero_grad()
        optimizer.zero_grad()

        # Every 200 epochs the effective batch size is doubled.
        if batch_epoch_count % 200 == 0:
            batch_size = batch_size*2
            accumulation_steps = int(batch_size/batch_size_max_once)
        batch_epoch_count = batch_epoch_count + 1

        # Number of accumulated (effective) batches in one epoch.
        train_size = len(train_dataloader) / accumulation_steps

        print("Batch Size: " + str(batch_size))
        print(float(accumulation_steps))

        for step, batch in enumerate(train_dataloader):
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)
            ideology_loss = loss_fct(P_ideology, b_ideology.float())

            # Scale so the accumulated gradient matches one large batch.
            loss_train = ideology_loss / accumulation_steps
            total_train_loss += loss_train.item()

            # Perform a backward pass to calculate the gradients.
            loss_train.backward()

            # Step once per accumulation group, and also on the final batch so
            # a trailing partial group is applied (and its loss counted)
            # instead of being wiped by the next epoch's zero_grad.
            is_last_batch = (step + 1) == num_train_batches
            if (step + 1) % accumulation_steps == 0 or is_last_batch:
                # Clip the gradient norm to 1.0 against exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                # Update the learning rate.
                scheduler.step()

                # Gradients accumulate by default, so clear them explicitly.
                model.zero_grad()
                optimizer.zero_grad()

        print("Learning rate: ", scheduler.get_last_lr())

        # Mean loss per accumulated batch.
        avg_train_loss = total_train_loss / train_size

        # Measure how long this epoch took.
        training_time = format_time(time.time() - t0)

        print("  Average training loss: {0:.6f}".format(avg_train_loss))

        # ========================================
        #               Validation
        # ========================================
        t1 = time.time()

        # Evaluation mode: dropout layers behave differently than in training.
        model.eval()

        # Tracking variables
        total_eval_accuracy = 0
        total_eval_loss = 0

        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            # No compute graph is needed for the validation forward pass.
            with torch.no_grad():
                P_ideology = model(input_ids = b_input_ids, attention_mask = b_input_mask)

                loss_val = loss_fct(P_ideology, b_ideology.float())

                # Accumulate the validation loss.
                total_eval_loss += loss_val.item()

                # Move predictions and labels to CPU
                P_ideology = P_ideology.to('cpu')
                b_ideology = b_ideology.to('cpu')

                # Per-batch accuracy, averaged over the batches below.
                total_eval_accuracy += predict_binary(P_ideology, b_ideology)

        # Report the final accuracy for this validation run.
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("Avg Val Accuracy: {0:.6f}".format(avg_val_accuracy))

        # Calculate the average loss over all of the batches.
        avg_val_loss = total_eval_loss / len(validation_dataloader)

        # Measure how long the validation run took.
        validation_time = format_time(time.time() - t1)

        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))

        # Record all statistics from this epoch.
        training_stats.append(
            {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
            }
        )

        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
            }

        es(avg_val_loss, avg_val_accuracy, model_save_state, model_save_path)
        last_epoch = epoch_i + 1
        if es.early_stop:
            break  # early stop criterion is met, we can stop now

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

    writer.close()

    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max

    return training_stats, last_epoch, min_val_loss, max_val_acc

from torch.utils.tensorboard import SummaryWriter

#import EarlyStopping
def train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer, scheduler, patience, verbose, delta, seedVal, continue_train = False):
    """Train ``model`` with BCE loss, gradient accumulation and early stopping.

    Every epoch performs one pass over the training loader followed by one
    pass over the validation loader.  After each epoch the validation loss,
    the validation accuracy and a checkpoint dict (``epoch``, ``state_dict``,
    ``optimizer``) are handed to an ``EarlyStopping`` instance together with
    ``model_save_path``.

    Args:
        model_save_path: Path forwarded to the early-stopping helper; it is
            also the checkpoint that is read when ``continue_train`` is set.
        model: Module called as ``model(input_ids=..., attention_mask=...)``.
            Its output is passed directly to ``torch.nn.BCELoss``, which
            requires values in [0, 1].
        tokenizer: Unused; kept so existing callers keep working.
        datasetTrain: Training dataset handed to ``return_batches_datasets``.
        datasetVal: Validation dataset handed to ``return_batches_datasets``.
            Batches of both loaders are indexed as [0] input ids,
            [1] attention mask, [2] labels.
        epochs: Exclusive upper bound of the epoch index.
        batch_size: Effective batch size.  At most 16 samples are pushed
            through the model at once; larger sizes are reached by
            accumulating gradients over several loader batches.
        optimizer: Optimizer stepped once per accumulated batch.
        scheduler: LR scheduler stepped once per optimizer step.
        patience: Forwarded to ``EarlyStopping``.
        verbose: Forwarded to ``EarlyStopping``.
        delta: Forwarded to ``EarlyStopping``.
        seedVal: Seed forwarded to ``create_determinism``.
        continue_train: When true, model/optimizer state and the epoch counter
            are restored from ``model_save_path`` before training starts.

    Returns:
        Tuple ``(training_stats, last_epoch, min_val_loss, max_val_acc)``:
        the list of per-epoch statistic dicts, the 1-based number of the last
        epoch that ran, and ``val_loss_min`` / ``val_acc_max`` as tracked by
        the early-stopping helper.
    """
    loss_fct = torch.nn.BCELoss()
    create_determinism(seedVal)

    # Number of samples sent through the model in one forward pass.
    batch_size_max_once = min(16, batch_size)
    # Must be an integer: a fractional value would mis-scale the loss and make
    # the `(step + 1) % accumulation_steps` test fire at irregular intervals.
    accumulation_steps = max(1, batch_size // batch_size_max_once)

    es = EarlyStopping(patience, verbose, delta)
    writer = SummaryWriter()

    # Per-epoch losses, validation accuracy and timings.
    training_stats = []

    # Measure the total training time for the whole run.
    total_t0 = time.time()
    train_dataloader, validation_dataloader = return_batches_datasets(datasetTrain, datasetVal, batch_size_max_once)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    epoch_start = 0
    if continue_train:
        # Resume from the checkpoint this function writes through the
        # early-stopping helper; map_location keeps this working when the
        # checkpoint was written on a different device.
        checkpoint = torch.load(model_save_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch_start = checkpoint['epoch']

    # Keep a handle on the module as passed in: its state_dict has no
    # "module." prefix, so checkpoints stay loadable into a bare model even
    # when training itself runs under DataParallel.
    base_model = model
    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = torch.nn.DataParallel(model)

    print(device)

    torch.cuda.empty_cache()
    model.to(device)
    optimizer_to(optimizer, device)

    # Defined up front so the return value exists even if no epoch runs.
    last_epoch = epoch_start
    for epoch_i in range(epoch_start, epochs):
        print("---------Epoch----------" + str(epoch_i))

        # ========================================
        #               Training
        # ========================================
        t0 = time.time()
        total_train_loss = 0

        # `train` only switches the mode (dropout/batchnorm behaviour).
        model.train()
        model.zero_grad()
        optimizer.zero_grad()

        # Loss collected for the accumulation group currently in flight.
        pending_loss = 0

        # Every 200th epoch the effective batch size is doubled.
        if (epoch_i + 1) % 200 == 0:
            batch_size = batch_size * 2
            accumulation_steps = max(1, batch_size // batch_size_max_once)

        # Number of (possibly fractional) accumulation groups per epoch.
        train_size = len(train_dataloader) / accumulation_steps

        print("Batch Size: " + str(batch_size))
        print(accumulation_steps)

        for step, batch in enumerate(train_dataloader):
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            P_ideology = model(input_ids=b_input_ids, attention_mask=b_input_mask)
            loss = loss_fct(P_ideology, b_ideology.float())

            # Scale so the accumulated gradient is the mean over the group.
            loss_train = loss / accumulation_steps
            pending_loss += loss_train.item()
            loss_train.backward()

            if (step + 1) % accumulation_steps == 0:
                # Clip the gradient norm to 1.0 against exploding gradients.
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()

                model.zero_grad()
                optimizer.zero_grad()

                total_train_loss += pending_loss
                pending_loss = 0

        # Count the loss of a trailing, incomplete accumulation group so the
        # epoch average covers every batch.  Its gradients are not applied;
        # they are cleared by the zero_grad() calls at the next epoch start.
        total_train_loss += pending_loss

        print("Learning rate: ", scheduler.get_last_lr())

        avg_train_loss = total_train_loss / train_size
        training_time = format_time(time.time() - t0)

        print("  Average training loss: {0:.6f}".format(avg_train_loss))

        # ========================================
        #               Validation
        # ========================================
        t0 = time.time()

        # Evaluation mode: dropout layers behave differently.
        model.eval()

        total_eval_accuracy = 0
        total_eval_loss = 0

        for batch in validation_dataloader:
            b_input_ids = batch[0].to(device)
            b_input_mask = batch[1].to(device)
            b_ideology = batch[2].to(device)

            # No graph is needed for evaluation.
            with torch.no_grad():
                P_ideology = model(input_ids=b_input_ids, attention_mask=b_input_mask)
                loss_val = loss_fct(P_ideology, b_ideology.float())
                total_eval_loss += loss_val.item()

                # predict_binary works on CPU tensors.
                total_eval_accuracy += predict_binary(P_ideology.to('cpu'), b_ideology.to('cpu'))

        # Mean of the per-batch accuracies / losses.
        avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
        print("Avg Val Accuracy: {0:.6f}".format(avg_val_accuracy))

        avg_val_loss = total_eval_loss / len(validation_dataloader)
        validation_time = format_time(time.time() - t0)

        print("Avg Validation Loss: {0:.6f}".format(avg_val_loss))

        # Mirror the epoch statistics to TensorBoard.
        writer.add_scalar('Loss/train', float(avg_train_loss), epoch_i + 1)
        writer.add_scalar('Loss/validation', float(avg_val_loss), epoch_i + 1)
        writer.add_scalar('Accuracy/validation', float(avg_val_accuracy), epoch_i + 1)

        # Record all statistics from this epoch.
        training_stats.append(
            {
                'epoch': epoch_i + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Valid. Accur.': avg_val_accuracy,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )

        model_save_state = {
            'epoch': epoch_i + 1,
            'state_dict': base_model.state_dict(),
            'optimizer': optimizer.state_dict()
        }

        es(avg_val_loss, avg_val_accuracy, model_save_state, model_save_path)
        last_epoch = epoch_i + 1
        if es.early_stop:
            break  # early stop criterion is met, we can stop now

    writer.close()

    print("")
    print("Training complete!")

    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

    min_val_loss = es.val_loss_min
    max_val_acc = es.val_acc_max

    return training_stats, last_epoch, min_val_loss, max_val_acc

def print_summary(training_stats):
    """Turn the per-epoch training log into a DataFrame indexed by epoch.

    Also adjusts global pandas display options (4 decimal places, up to 500
    rows and 500 columns) so the table renders in full.

    Args:
        training_stats: List of per-epoch dicts; each must contain an
            ``'epoch'`` key, which becomes the row index.

    Returns:
        ``pandas.DataFrame`` with one row per epoch, indexed by ``'epoch'``.
    """
    # Use the fully qualified option name: the bare pattern 'precision' also
    # matches 'styler.format.precision' on newer pandas and raises OptionError.
    pd.set_option('display.precision', 4)
    pd.set_option('display.max_rows', 500)
    pd.set_option('display.max_columns', 500)

    # Create a DataFrame from the training statistics and use the 'epoch'
    # column as the row index.
    df_stats = pd.DataFrame(data=training_stats)
    df_stats = df_stats.set_index('epoch')

    return df_stats

def plot_results(df_stats, last_epoch):
    """Plot the training and validation loss curves against the epoch index.

    Args:
        df_stats: Table with 'Training Loss' and 'Valid. Loss' columns.
        last_epoch: Number of the last epoch; only used to build a list of
            tick positions that is currently not applied to the axis.
    """
    # Seaborn styling first, then a larger font and figure.
    sns.set(style='darkgrid')
    sns.set(font_scale=1.5)
    plt.rcParams["figure.figsize"] = (12,6)

    plt.figure(1)

    # (column, line format, legend label) for each curve, drawn in order.
    curves = (
        ('Training Loss', 'b-o', "Training_Loss"),
        ('Valid. Loss', 'g-o', "Val_Loss"),
    )
    for column, line_format, legend_label in curves:
        plt.plot(df_stats[column], line_format, label=legend_label)

    # Label the plot.
    plt.title("Training & Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()

    # One tick per epoch; kept for the disabled `plt.xticks(x_ticks)` variant.
    x_ticks = list(range(1, last_epoch+1))
    #plt.xticks(x_ticks)
    plt.xticks(rotation=90)

    plt.show()

def run_wholeprocess_fnc(tokenizer, model_current, train_path, val_path, max_len, doc_stride, batch_size, num_warmup_steps, learning_rate, seedVal):
    """Run the full pipeline: load data, tokenize, train, test and plot.

    The function has two phases:

    1. A single train/test run on the "ambigious" dataset, checkpointed to a
       hard-coded path.
    2. Five further train/test runs with a fresh random seed each, whose
       validation/test metrics are averaged and printed.

    Args:
        tokenizer: Tokenizer forwarded to the dataset/preprocessing helpers.
        model_current: Model identifier forwarded to the prepare helpers.
        train_path: Not referenced in the body (the dataset path is hard-coded).
        val_path: Not referenced in the body.
        max_len: Forwarded to ``preprocessing_for_bert``.
        doc_stride: Forwarded to ``preprocessing_for_bert``.
        batch_size: Batch size forwarded to preparation, training and testing.
        num_warmup_steps: Forwarded to the prepare helpers.
        learning_rate: Forwarded to the prepare helpers.
        seedVal: Seed for the dataset split and the first training run; it is
            overwritten with a random value in every iteration of phase 2.

    Returns:
        None. Results are printed and plotted.

    NOTE(review): ``epochs``, ``patience``, ``verbose`` and ``delta`` are not
    parameters; they are read from the enclosing (module) scope.
    """

#--------------LOAD DATASETS--------------#

    model_save_path = './models/BERT_SERP/bert_titleonly_ambigious'  
    df = load_dataset_ambigious("./dataset/batches_cleaned/stance/FullDataset_16.09.2021.tsv")
    
    #trainpath = "./dataset/ideology/train_new.tsv"
    #valPath = "./dataset/ideology/val_new.tsv"
    #testPath = "./dataset/ideology/test_new.tsv"
    
    # NOTE(review): these three ratios are never read below.
    trainPer = 0.8
    valPer = 0.2
    testPer = 0.2
    
    # `df` is rebound to the training portion of the split.
    df, dfVal, dfTest = sample_dataset_stance(df, seedVal)
    
    #create_new_splits_and_writethem_to_csvfiles
    #create_train_val_test_split(trainpath, valPath, testPath, testPer, valPer, seedVal)
    
    ##df = load_dataset(trainpath)
    #dfVal = load_dataset(valPath)
    #dfTest = load_dataset(testPath)
    
    # Report the number of sentences.
    print('Number of training sentences: {:,}'.format(df.shape[0]))
    print('Number of val sentences: {:,}'.format(dfVal.shape[0]))
    print('Number of test sentences: {:,}'.format(dfTest.shape[0]))

    # The empty-list initialisations below are overwritten by the
    # generate_datasets_ambigious call (except `stances_*`, which stay unused).
    sentencesQueryTitle_Train = []
    sentencesQueryTitleCont_Train = []
    stances_Train = []
    labels_Train = []
    
    #--------------DATASETS-------------#

    #sentencesQueryTitle_Train, sentencesQueryTitleCont_Train, sentencesQueryTitleStance_Train, sentencesQueryTitleStanceCont_Train, stances_Train, labels_Train = generate_datasets_ideology (df, tokenizer)
    sentencesQueryTitle_Train, sentencesQueryTitleCont_Train, labels_Train = generate_datasets_ambigious(df, tokenizer)
    
    
    
    sentencesQueryTitle_Val = []
    sentencesQueryTitleCont_Val = []
    stances_Val = []
    labels_Val = []

 
    #sentencesQueryTitle_Val, sentencesQueryTitleCont_Val, sentencesQueryTitleStance_Val, sentencesQueryTitleStanceCont_Val, stances_Val, labels_Val = generate_datasets_ideology (dfVal, tokenizer)
    sentencesQueryTitle_Val, sentencesQueryTitleCont_Val, labels_Val = generate_datasets_ambigious(dfVal, tokenizer)
    
    sentencesQueryTitle_Test = []
    sentencesQueryTitleCont_Test = []
    stances_Test = []
    labels_Test = []
   
    #sentencesQueryTitle_Test, sentencesQueryTitleCont_Test, sentencesQueryTitleStance_Test, sentencesQueryTitleStanceCont_Test, stances_Test, labels_Test = generate_datasets_ideology (dfTest, tokenizer)
    sentencesQueryTitle_Test, sentencesQueryTitleCont_Test, labels_Test = generate_datasets_ambigious(dfTest, tokenizer)
    
    # Show the first training example as a sanity check.
    print(sentencesQueryTitle_Train[0])

    #--------------DATASETS-------------#

    # Only the `...Cont_*` sequences are encoded; the `sentencesQueryTitle_*`
    # lists are not used beyond the print above.
    all_input_ids_Train, all_input_masks_Train  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Train, max_len, doc_stride)
    all_input_ids_Val, all_input_masks_Val  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Val, max_len, doc_stride)
    all_input_ids_Test, all_input_masks_Test  = preprocessing_for_bert(tokenizer, sentencesQueryTitleCont_Test, max_len, doc_stride)
    
    #all_input_ids_Train, all_input_masks_Train, stance_labels_Train, ideology_labels_Train = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Train, stances_Train, labels_Train, max_len, doc_stride) #train
    #all_input_ids_Val, all_input_masks_Val, stance_labels_Val, ideology_labels_Val = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Val, stances_Val, labels_Val, max_len, doc_stride) #val
    #all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test = transform_sequences_longer_ideology(tokenizer, sentencesQueryTitleCont_Test, stances_Test, labels_Test, max_len, doc_stride) #test

    # ---- Phase 1: single training run on the "ambigious" data ----
    model, datasetTrain, datasetVal, optimizer, scheduler = prepare_for_training_ambigious(all_input_ids_Train, all_input_masks_Train, labels_Train, all_input_ids_Val,
                                                                                                               all_input_masks_Val, labels_Val, model_current, batch_size, epochs, num_warmup_steps, learning_rate)    
    training_stats, last_epoch, min_val_loss, max_val_acc = train_stance(model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer,
                                                                          scheduler, patience, verbose, delta, seedVal)
    
    
    # NOTE(review): `stance_labels_Test` is not assigned anywhere in this
    # function (its only assignment is in the commented-out
    # transform_sequences_longer_ideology call above). Unless it exists as a
    # global, this line raises NameError -- confirm.
    avg_test_loss, avg_test_acc = run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, labels_Test, batch_size)
    df_stats = print_summary(training_stats)
    plot_results(df_stats, last_epoch)
    
    #--------------TRAINING-------------#
    # ---- Phase 2: five repeated runs with random seeds, metrics averaged ----
    # Per-forward-pass batch size is capped at 16.
    batch_size_cuda = 16
    if batch_size < 16:
        batch_size_cuda = batch_size
        
    num_iterations = 5
    total_val_loss = 0.0
    total_val_acc = 0.0
    total_test_loss = 0.0
    total_test_acc = 0.0
    
    for i in range(0, num_iterations):
        
        # Draw a fresh seed for this run; this replaces the caller's seedVal.
        value = randint(0, 100)
        seedVal = value
        print("******************")
        print("This is the iteration " + str(i))
        
        # One checkpoint file per iteration (iteration index appended).
        model_save_path = "model_save/ideology/model_news2a_qtitle.t7" + str(i)

        # NOTE(review): `stance_labels_Train`, `ideology_labels_Train`,
        # `stance_labels_Val`, `ideology_labels_Val` and (below)
        # `ideology_labels_Test` are not assigned in this function; they only
        # appear in the commented-out calls above. This phase also uses
        # `prepare_for_training` rather than `prepare_for_training_ambigious`.
        # Presumably left over from the ideology pipeline -- confirm.
        model, datasetTrain, datasetVal, optimizer, scheduler = prepare_for_training(all_input_ids_Train, all_input_masks_Train, stance_labels_Train, ideology_labels_Train, all_input_ids_Val,
                                                                                                               all_input_masks_Val, stance_labels_Val, ideology_labels_Val, model_current, batch_size_cuda, epochs, num_warmup_steps, learning_rate)    
        # Preparation receives `batch_size_cuda`, training receives `batch_size`.
        training_stats, last_epoch, min_val_loss, max_val_acc = train_stance (model_save_path, model, tokenizer, datasetTrain, datasetVal, epochs, batch_size, optimizer,
                                                                          scheduler, patience, verbose, delta, seedVal)

        avg_test_loss, avg_test_acc = run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size)
        df_stats = print_summary(training_stats)
        plot_results(df_stats, last_epoch)
        
        # Accumulate this run's metrics for the averages printed at the end.
        total_val_loss += min_val_loss
        total_val_acc += max_val_acc
        total_test_loss += avg_test_loss
        total_test_acc += avg_test_acc

        print('Min Val Loss: ' + str(min_val_loss))
        print('Max Val Acc: ' + str(max_val_acc))
        print('Test Loss: ' + str(avg_test_loss))
        print('Test Acc: ' + str(avg_test_acc))
        
        
    # Averages over all iterations.
    print("******************")
    print('Avg Min Val Loss: ' + str(total_val_loss/num_iterations))
    print('Avg Max Val Acc: ' + str(total_val_acc/num_iterations))
    print('Avg Test Loss: ' + str(total_test_loss/num_iterations))
    print('Avg Test Acc: ' + str(total_test_acc/num_iterations))
    
    
    #model_to_save.save_pretrained('model_save')
    #tokenizer.save_pretrained('model_save')

    # Good practice: save your training arguments together with the trained model
    #torch.save(args, os.path.join('model_save', 'training_args.bin'))
    #model_args = str(max_len) + '_' + str(doc_stride) + '_' + str(batch_size) + "_" + str(learning_rate) + "_warmup" + str(num_warmup_steps) + "_seedVal" + str(seedVal)
    #model_path = model_save_path + '/model_' + model_args
    #model.save_pretrained(model_save_path)
    #torch.save(model.state_dict(), model_path)

def create_determinism(seedVal):
    """Seed the PyTorch (CPU and CUDA), NumPy and ``random`` generators.

    Args:
        seedVal: Integer seed applied to every generator.

    Returns:
        None.
    """
    torch.manual_seed(seedVal)
    # The CUDA seeding calls are silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seedVal)
    torch.cuda.manual_seed(seedVal)
    np.random.seed(seedVal)
    random.seed(seedVal)
    # Left disabled, as in the original:
    #os.environ['PYTHONHASHSEED'] = str(seedVal)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False

from torch.utils.data import DataLoader, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig

def run_test_stance(model_savepath, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size = 16):
    """Evaluate a saved ``StanceDetectionClass`` checkpoint on the test data.

    Args:
        model_savepath: Checkpoint file containing a ``'state_dict'`` entry.
        all_input_ids_Test: Tensor of input ids.
        all_input_masks_Test: Tensor of attention masks.
        stance_labels_Test: Labels handed to ``preprocess_fnc``, which yields
            the relatedness, stance and two mmd-symbol tensors.
        ideology_labels_Test: Unused; kept so existing callers keep working.
        batch_size: Evaluation batch size.

    Returns:
        Tuple ``(avg_test_loss, avg_test_accuracy)``, each the mean of the
        per-batch values.
    """
    loss_fct_relatedness = torch.nn.BCEWithLogitsLoss()

    t_test_relatedness, t_test_stance, t_test_mmd_symbol, t_test_mmd_symbol_ = preprocess_fnc(stance_labels_Test)

    # Sequential loader: evaluation order does not matter, so no shuffling.
    prediction_data = TensorDataset(all_input_ids_Test, all_input_masks_Test, t_test_relatedness, t_test_stance, t_test_mmd_symbol, t_test_mmd_symbol_)
    prediction_sampler = SequentialSampler(prediction_data)
    prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size, num_workers=0)

    # The device must be known before anything is moved onto it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_current = 'bert-base-uncased'
    model = StanceDetectionClass(model_current)
    # map_location lets a GPU-written checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_savepath, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # No optimizer is rebuilt here: evaluation never takes an optimizer step.

    torch.cuda.empty_cache()
    model.to(device)
    model.eval()

    total_test_loss = 0.0
    total_test_accuracy = 0.0

    # Weights of the stance loss and the mmd term in the combined loss.
    alpha = 1.3
    beta = 1e-3

    for batch in prediction_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_relatedness = batch[2].to(device)
        b_labels = batch[3].to(device)
        b_mmd_symbol = batch[4].to(device)
        b_mmd_symbol_ = batch[5].to(device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            # Sums of the two indicator tensors over the batch.
            n1 = torch.sum(b_mmd_symbol, dim=0)
            n2 = torch.sum(b_mmd_symbol_, dim=0)

            # Column vectors so they broadcast against theta_d_layer.
            aa = torch.reshape(b_mmd_symbol, (-1, 1))
            bb = torch.reshape(b_mmd_symbol_, (-1, 1))

            theta_d_layer, P_relatedness, P_stance = model(input_ids=b_input_ids, attention_mask=b_input_mask)

            # NOTE(review): the zero fallback is shaped (batch_size, 1) while
            # the other branch reduces over dim=1; confirm the two shapes are
            # meant to broadcast, also for a smaller final batch.
            if n1 == 0:
                d1 = torch.zeros(batch_size, 1, device=device)
            else:
                d1 = torch.div(torch.sum(theta_d_layer * aa, dim=1), n1)

            if n2 == 0:
                d2 = torch.zeros(batch_size, 1, device=device)
            else:
                d2 = torch.div(torch.sum(theta_d_layer * bb, dim=1), n2)

            mmd_loss = torch.sum(d1 - d2)

            relatedness_loss = loss_fct_relatedness(P_relatedness, b_relatedness.float())
            stance_loss = loss_fct_relatedness(P_stance, b_labels.float())

            loss_test = relatedness_loss + alpha * stance_loss - beta * mmd_loss
            total_test_loss += loss_test.item()

            # Move predictions and labels to CPU for the accuracy helper.
            P_relatedness = P_relatedness.to('cpu')
            P_stance = P_stance.to('cpu')
            b_labels = b_labels.to('cpu')

            # NOTE(review): the `predict` helper defined at the top of this
            # file takes two arguments; confirm a three-argument variant is
            # the one in scope when this runs.
            total_test_accuracy += predict(P_relatedness, P_stance, b_labels)

    # Mean of the per-batch values.
    avg_test_loss = total_test_loss / len(prediction_dataloader)
    avg_test_accuracy = total_test_accuracy / len(prediction_dataloader)

    return avg_test_loss, avg_test_accuracy

from torch.utils.data import DataLoader, SequentialSampler
from transformers import BertForSequenceClassification, AdamW, BertConfig

def run_test_ideology(model_save_path, all_input_ids_Test, all_input_masks_Test, stance_labels_Test, ideology_labels_Test, batch_size = 16):
    """Evaluate a saved ``IdeologyDetectionClass`` checkpoint on the test data.

    Args:
        model_save_path: Checkpoint file containing a ``'state_dict'`` entry.
        all_input_ids_Test: Tensor of input ids.
        all_input_masks_Test: Tensor of attention masks.
        stance_labels_Test: Forwarded to ``preprocess_ideology_new``.
        ideology_labels_Test: Forwarded to ``preprocess_ideology_new``.
        batch_size: Evaluation batch size.

    Returns:
        Tuple ``(avg_test_loss, avg_test_accuracy)``, each the mean of the
        per-batch values.
    """
    t_ideology_labels_test = preprocess_ideology_new(stance_labels_Test, ideology_labels_Test)

    # Sequential loader: evaluation order does not matter, so no shuffling.
    prediction_data = TensorDataset(all_input_ids_Test, all_input_masks_Test, t_ideology_labels_test)
    prediction_sampler = SequentialSampler(prediction_data)
    prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)

    loss_fct = torch.nn.BCELoss()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_current = 'bert-base-uncased'
    model = IdeologyDetectionClass(model_current)
    # map_location lets a GPU-written checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_save_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    # No optimizer is rebuilt here: evaluation never takes an optimizer step.

    torch.cuda.empty_cache()
    model.to(device)
    model.eval()

    total_test_loss = 0.0
    total_test_accuracy = 0.0

    for batch in prediction_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_ideologylabels = batch[2].to(device)

        # No gradients are needed for evaluation.
        with torch.no_grad():
            P_ideology = model(b_input_ids, attention_mask=b_input_mask)
            loss = loss_fct(P_ideology, b_ideologylabels.float())

            # Move predictions and labels to CPU for the accuracy helper.
            batch_predictions = P_ideology.detach().cpu()
            batch_labels = b_ideologylabels.to('cpu')

            total_test_loss += loss.item()
            total_test_accuracy += predict_binary(batch_predictions, batch_labels)

    # Mean of the per-batch values.
    avg_test_loss = total_test_loss / len(prediction_dataloader)
    avg_test_accuracy = total_test_accuracy / len(prediction_dataloader)

    return avg_test_loss, avg_test_accuracy

### import os
import string
import tensorflow as tf
import torch
import pandas as pd
import numpy as np
from random import randint
import random
import time
import datetime
from transformers import AutoModel
from transformers import DistilBertModel
from torch.utils.data import TensorDataset, random_split
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# ---- Dataset and checkpoint locations ----
train_path = './dataset/fnc/train'
val_path = './dataset/fnc/test'

model_save_path = './model_save/'
model_name = 'querytitle_model_base_9'

# ---- Available model identifiers / local fine-tuned checkpoints ----
model_base = 'bert-base-uncased'
model_roberta = "roberta-base"
model_finetuned = './models/2_a/'
model_finetuned2 = './models/2_b/'
model_tiny_bert = './models/tiny_bert/'

# Model used for this run.
model_current = model_base
tokenizer = load_tokenizer(model_current)

# ---- Tokenization ----
max_len = 512
doc_stride = 128

# ---- Optimisation hyper-parameters ----
batch_size = 16
epochs = 30
num_warmup_steps = 10
learning_rate = 2e-6

##-----Early Stopping
# NOTE(review): patience (4000) is far larger than epochs (30), so early
# stopping presumably never triggers with these settings -- confirm intent.
patience = 4000
verbose = True
delta = 0.000001
seedVal = 20

# True: run the full train/test pipeline; False: only test a saved model.
train_flag = True

if train_flag:
    run_wholeprocess_fnc(tokenizer, model_current, train_path, val_path, max_len, doc_stride, batch_size, num_warmup_steps, learning_rate,seedVal)  
else:
    # `device` was previously only assigned by a commented-out
    # `device = run_utils()` call, so this branch raised NameError.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    avg_test_loss, avg_test_acc = run_onlytest_ideology(tokenizer, model_save_path + model_name, torch.nn.BCELoss(), device)
    print(avg_test_loss, avg_test_acc)






