In [None]:
import sys
import os
import collections
import json
from itertools import cycle
from ast import literal_eval
from dataclasses import dataclass, asdict
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from pathlib import Path
from datetime import datetime
from tqdm.auto import tqdm, trange

import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch import optim
from torch.optim import lr_scheduler
import torchmetrics
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import *
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput

import matplotlib.pyplot as plt
from IPython.display import display

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Explicitly set tokenizers parallelism (see https://github.com/huggingface/transformers/issues/5486).
os.environ.update(TOKENIZERS_PARALLELISM="true")

In [None]:
# Output / data locations; each one can be overridden through an environment variable.
result_folder = os.environ.get("scratch_result_folder", '../result')
scratch_data_folder = os.environ.get("scratch_data_folder")
repo_folder = os.environ.get("style_models_repo_folder")
# An unset or empty repo folder falls back to the relative data directory.
data_folder = '../data' if not repo_folder else f"{repo_folder}/data"

In [None]:
# Tokenizer shared by collate_fn; must match the encoder used by the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [None]:
@dataclass
class MyTrainingArgs:
    """Configuration for one training / inference run of the style-change model.

    ``model_folder`` is filled in by ``__post_init__`` only when it is not given
    explicitly: ``{result_folder}/{model_name or 'pan22-dataset<dataset_idx>'}/<YYYYmmdd-HH:MM:SS>``,
    where ``result_folder`` is the notebook-level global.
    """
    # dataset args (not an ideal home for them, especially a data split)
    dataset_idx: int = 0

    # model args
    base_model_name: str = 'bert-base-uncased'
    freeze_bert: bool = False
    use_pooler: bool = False

    # training args
    num_epoch: int = 5
    lr: float = 5e-5
    # NOTE: these two have no type annotation, so @dataclass treats them as plain
    # class attributes, not fields -- they cannot be passed to __init__.
    num_warmup_steps = 500
    warmup_ratio = 0.1
    model_folder: Optional[str] = None  # inferred in __post_init__ when None
    model_name: Optional[str] = None  # names the run folder; defaults to f"pan22-dataset{dataset_idx}"
    loss_fn: Optional[str] = None  # 'BCEWithLogitsLoss' / 'MSELoss'; anything else -> CrossEntropyLoss in BertForSCD
    do_mlm: bool = False
    method: str = 'successive paragraph'
    # Temperature for cosine similarity (currently unused).
    # SimCSE uses 0.05, but that is for cross-entropy on single-label classification.
    # For multi-label training with BCEWithLogitsLoss, the sigmoid makes temp < 1 tend to learn slower.
    cos_temp: float = 1.

    # data loader args
    batch_size: int = 32
    max_length: int = 64
    shuffle: bool = False
    num_workers: int = 4
    data_limit: Optional[int] = None  # if not None, keep only the first `data_limit` rows

    # post training args
    save_best_only: bool = True
    load_best_at_end: bool = True
    early_stop_patience: int = 1

    def __post_init__(self):
        # Respect an explicitly supplied folder; previously it was always overwritten.
        if self.model_folder is None:
            run_name = self.model_name or f"pan22-dataset{self.dataset_idx}"
            timestamp = datetime.now().strftime('%Y%m%d-%H:%M:%S')
            self.model_folder = f"{result_folder}/{run_name}/{timestamp}"

In [None]:
class PastelDataset(Dataset):
    """Map-style dataset over the processed PASTEL sentences of one split.

    Reads ``{data_folder}/pastel/processed/{split}/pastel.csv`` (``data_folder`` is the
    notebook-level global), drops every row with a missing value in any column, and
    yields ``{'text': <value of the 'output.sentences' column>}``.

    Args:
        split: sub-folder under ``pastel/processed`` (e.g. 'train', 'valid').
        limit: if not None, keep only the first ``limit`` rows (after dropping NaNs).
            Truncating by position may change the label distribution.
    """

    def __init__(self, split, limit=None):
        self.split = split
        df = pd.read_csv(f"{data_folder}/pastel/processed/{self.split}/pastel.csv")
        df = df.dropna().reset_index(drop=True)
        if limit is not None:
            # Index is already 0..n-1 after the reset, so a positional slice keeps it contiguous.
            df = df.iloc[:limit]
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Samplers may hand over tensor indices; pandas expects plain ints / lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return {'text': self.df.iloc[idx]['output.sentences']}


In [None]:
def collate_fn(batch):
    """Tokenize a batch of PastelDataset items for BertForSCD.

    The model compares each text with the next one, so N texts form N-1 pairs.
    PASTEL carries no style-change annotation here, so ``changes`` is a placeholder
    of N-1 zeros that only satisfies the model's forward signature -- it is NOT a
    real label.

    Relies on the notebook globals ``tokenizer``, ``my_training_args`` (for
    ``max_length``) and ``device``.

    Returns:
        dict with the tokenizer outputs (moved to ``device``), plus ``changes``
        (list of N-1 zeros) and ``texts`` (the raw strings, in batch order).
    """
    texts = [item['text'] for item in batch]
    encoded = tokenizer(
        text=texts,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=my_training_args.max_length,
    ).to(device)
    batch_out = dict(encoded.items())
    batch_out['changes'] = [0] * (len(batch) - 1)
    batch_out['texts'] = texts
    return batch_out

In [None]:
@dataclass
class SCDOutput(ModelOutput):
    """Return container of ``BertForSCD.forward``.

    NOTE(review): ModelOutput supports positional / tuple-style access, so field
    order likely matters -- do not reorder these fields.

    Fields:
        loss: loss of ``logits`` against the ``changes`` labels passed to forward.
        logits: one row of 2 scores per adjacent pair of texts in the batch.
        sent_embs: detached sentence embeddings, one per text (a single tensor, not a list).
        hidden_states / attentions: passed through from the base encoder when requested.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    sent_embs: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [None]:
class BertForSCD(BertPreTrainedModel):
    """Style-change detector over consecutive texts of a batch.

    Each text is encoded by a pretrained encoder into one sentence embedding (the
    pooler output when ``use_pooler`` is set and available, otherwise the first-token
    hidden state). Embeddings of every adjacent pair (i, i+1) are concatenated and
    classified by a 2-layer GELU MLP into 2 classes (index 1 is read as "has change"
    by the notebook).

    Args:
        config: model config handed to ``BertPreTrainedModel``.
        training_args: object providing ``use_pooler``, ``base_model_name``,
            ``do_mlm`` and ``loss_fn`` ('BCEWithLogitsLoss', 'MSELoss', anything
            else -> CrossEntropyLoss).
    """

    def __init__(self, config, training_args):
        super().__init__(config)
        self.use_pooler = training_args.use_pooler
        # Encoder weights come from the pretrained checkpoint, not from `config`.
        self.basemodel = AutoModel.from_pretrained(training_args.base_model_name)
        self.do_mlm = training_args.do_mlm  # MLM head is not implemented yet

        # Size the classifier from the encoder instead of hard-coding 768, so
        # non-base encoders work too. Attribute names are unchanged so existing
        # state_dict checkpoints still load.
        hidden_size = self.basemodel.config.hidden_size
        self.hidden1 = nn.Linear(2 * hidden_size, 256)
        self.gelu = nn.GELU()
        self.hidden2 = nn.Linear(256, 2)

        if training_args.loss_fn == 'BCEWithLogitsLoss':
            self.loss_fn = nn.BCEWithLogitsLoss()
        elif training_args.loss_fn == 'MSELoss':
            self.loss_fn = nn.MSELoss()
        else:
            self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, token_type_ids, attention_mask, output_sent_embs=False,
                output_hidden_states=False, output_attentions=False, **kwargs):
        """Classify every adjacent pair of texts in the batch.

        Args:
            input_ids, token_type_ids, attention_mask: tokenizer outputs for N texts.
            output_sent_embs: accepted for API compatibility; unused (sentence
                embeddings are always returned).
            output_hidden_states, output_attentions: forwarded to the encoder.
            **kwargs: may contain ``changes`` -- N-1 integer class labels, one per
                adjacent pair. Other keys (e.g. ``texts``) are ignored.

        Returns:
            SCDOutput. ``loss`` is None when no ``changes`` are supplied.
        """
        output = self.basemodel(input_ids=input_ids, token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                output_hidden_states=output_hidden_states,
                                output_attentions=output_attentions)

        if self.use_pooler and ('pooler_output' in output):
            sent_emb = output['pooler_output']
        else:
            sent_emb = output['last_hidden_state'][:, 0, :]

        # Row i of the pair matrix describes the transition text i -> text i+1.
        concat_embs = torch.cat([sent_emb[:-1], sent_emb[1:]], dim=-1)
        logits = self.hidden2(self.gelu(self.hidden1(concat_embs)))

        loss = None
        if kwargs.get('changes') is not None:
            # Build labels where the logits live rather than on the notebook-global device.
            labels = torch.as_tensor(kwargs['changes'], dtype=torch.long, device=logits.device)
            if isinstance(self.loss_fn, (nn.BCEWithLogitsLoss, nn.MSELoss)):
                # These losses need float targets shaped like the logits, not class indices.
                target = F.one_hot(labels, num_classes=logits.size(-1)).to(logits.dtype)
            else:
                target = labels
            loss = self.loss_fn(logits, target)

        return SCDOutput(loss=loss, logits=logits, sent_embs=sent_emb.detach(),
                         hidden_states=output.hidden_states, attentions=output.attentions)
        

In [None]:
# One dataset + loader per split. shuffle=False keeps rows in file order,
# presumably so consecutive texts in a batch stay in their original sequence.
train_set, val_set = (PastelDataset(split=split_name) for split_name in ('train', 'valid'))
train_loader, val_loader = (
    DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    for dataset in (train_set, val_set)
)

In [None]:
# Restore a trained style-change detector for inference.
model_folder = f"{result_folder}/PAN_SCD_CL/sp_run_7"
my_training_args = MyTrainingArgs()
config = AutoConfig.from_pretrained(my_training_args.base_model_name)
model = BertForSCD(config, my_training_args).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(f"{model_folder}/pytorch_model.bin", map_location=device))
model.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [None]:
# Collect qualitative examples of predicted style changes from the training set.
# Inference only: no_grad avoids building an autograd graph for every batch.
examples = collections.defaultdict(list)
with torch.no_grad():
    for i_iter, x in enumerate(train_loader):
        texts = x['texts']
        output = model(**x)
        prediction = output.logits.argmax(-1).cpu()
        # Pair index i corresponds to the transition texts[i] -> texts[i + 1].
        for i in torch.where(prediction == 1)[0].tolist():
            examples['has change'].append([texts[i], texts[i + 1]])
        # Keep only every third "no change" pair (subsampling).
        for i in torch.where(prediction == 0)[0][::3].tolist():
            examples['no change'].append([texts[i], texts[i + 1]])
        if len(examples['has change']) > 100:
            break

In [None]:
# Re-run the model on the last batch `x` from the loop above for closer inspection.
with torch.no_grad():
    output = model(**x)
prediction = output.logits.argmax(-1).cpu()


In [None]:
# Per-pair predicted classes for the last batch (1 = predicted change between texts i and i+1).
prediction

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0])

In [None]:
# Every other pair index predicted as a style change.
(prediction == 1).nonzero(as_tuple=True)[0][::2]

tensor([12, 22])

In [None]:
# Raw texts of the last batch; pair i in `prediction` compares texts[i] with texts[i + 1].
x['texts']

['He was welcomed and got comfortable very soon',
 'The person went into the really cute bakery.',
 'The woodwinds were particularly great at this concert.',
 'We saw a cute cat in the window.',
 'We gathered in a crowded room to celebrate the life of our friend.',
 'The folks in the water participated in the boat portion of the race.',
 'Thanks for the beautiful birthday flowers.',
 'we came up to a big wall',
 'Her groom was nervous.',
 "It was a wonderful celebration of Nick's achievements.",
 'The story had a great location with an inviting entrance.',
 "The runners in the marathon run a very long way, but at least it's a scenic route that passes by the ocean.",
 'The visitors team was just as pumped as the home team!',
 'we had a great time visiting the homes.',
 'IT IS A NICE BULIDINGS',
 'This multi legged creature is scary to the little girl',
 'It took 5 hours.',
 'At the beginning of the school year, university students always have some things they need to shop for.',
 'The p