In [1]:
%cd ../../

d:\online_predatory_conversation_detection


In [24]:
import torch
import pandas as pd

from transformers import AutoTokenizer, BertTokenizer, \
    DataCollatorForLanguageModeling, AutoModelForMaskedLM, \
    TrainingArguments, Trainer

from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from tqdm import tqdm
from multiprocessing import Pool

In [22]:
model_name = "distilroberta-base"

class MyDataset(Dataset):

    def __init__(self, df, tokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        
        self.input_ids, self.attention_mask = [None] * df.shape[0], [None] * df.shape[0]
        txts = df["text"].to_list()
        
        window_size = 512
        len_txts = len(txts)
        for start in tqdm(range(0, len_txts, window_size)):
            tmp = self.tokenizer(txts[start: min(start+window_size, len_txts)], truncation=True, padding="max_length", return_tensors='pt')
            self.input_ids[start: min(start+window_size, len_txts)] = tmp["input_ids"]
            self.attention_mask[start: min(start+window_size, len_txts)] = tmp["attention_mask"]
        print(start, len_txts)

        tmp = self.tokenizer(txts[start: min(start+window_size, len_txts)], truncation=True, padding="max_length", return_tensors='pt')
        self.input_ids[start: min(start+window_size, len_txts)] = tmp["input_ids"]
        self.attention_mask[start: min(start+window_size, len_txts)] = tmp["attention_mask"]

        self.attention_mask = torch.stack(self.attention_mask, dim=0)
        mask_filter = (self.attention_mask.sum(dim=1) <= 3)
        self.attention_mask = self.attention_mask[mask_filter]
        self.input_ids = torch.stack(self.input_ids, dim=0)
        print(f"before filter: {len(mask_filter)}, after filter: {len(mask_filter) - sum(mask_filter)}")
        self.input_ids = self.input_ids[mask_filter]
    
    
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx], "attention_mask": self.attention_mask[idx]}


In [23]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset_path = "data/dataset-v2/train.csv"
df = pd.read_csv(dataset_path)
df = df[(df["nauthor"] >= 2) & (df["conv_size"] > 6)]
df["text"].fillna("", inplace=True)
traindf, testdf = train_test_split(df, test_size=0.1)

dataset = MyDataset(traindf, tokenizer)
test_set = MyDataset(testdf, tokenizer)
del df, testdf, traindf
# df = pd.read_csv(dataset_path)
# df["text"].fillna("", inplace=True)

# df.head()

100%|██████████| 1421/1421 [03:22<00:00,  7.03it/s]


727040 727190
before filter: 727190, after filter: 645075


In [21]:
filter_records = (dataset.attention_mask.sum(dim=1) <= 3)

tensor([False, False, False,  ...,  True, False, False])

In [72]:
# df = pd.read_csv(dataset_path)
aa = [None] * 10
aa[:3] = (tokenizer(df["text"].iloc[:3].tolist(), return_tensors='pt', truncation=True, padding=True)["input_ids"])
torch.stack(aa[:3],), aa[:3]
# df["text"].tolist()[:128]

(tensor([[    0,   725,  3019,     4,     2,     2],
         [    0,  3592,     4,     2,     2,     2],
         [    0, 11613,  2923,    62,   116,     2]]),
 [tensor([   0,  725, 3019,    4,    2,    2]),
  tensor([   0, 3592,    4,    2,    2,    2]),
  tensor([    0, 11613,  2923,    62,   116,     2])])

In [28]:
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [29]:
model = AutoModelForMaskedLM.from_pretrained(model_name)

In [32]:
training_args = TrainingArguments(
    output_dir="output-pretraining",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
    per_device_train_batch_size=16
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=test_set,
    data_collator=data_collator,
)


In [33]:
trainer.train()
trainer.save_model("output-pretraining")
tokenizer.save_pretrained("output-pretraining/tokenizer")

  0%|          | 0/338853 [55:46<?, ?it/s]
                                                       
  0%|          | 500/169428 [05:34<31:29:38,  1.49it/s] 

{'loss': 0.06, 'learning_rate': 1.9940977878508866e-05, 'epoch': 0.01}


                                                        
  1%|          | 1000/169428 [11:08<30:50:47,  1.52it/s]

{'loss': 0.0598, 'learning_rate': 1.9881955757017733e-05, 'epoch': 0.02}


                                                        
  1%|          | 1500/169428 [16:36<29:09:09,  1.60it/s]

{'loss': 0.06, 'learning_rate': 1.9822933635526598e-05, 'epoch': 0.03}


                                                         
  1%|          | 2000/169428 [21:53<29:13:21,  1.59it/s]

{'loss': 0.0588, 'learning_rate': 1.9763911514035462e-05, 'epoch': 0.04}


                                                         
  1%|▏         | 2500/169428 [27:22<30:31:13,  1.52it/s]

{'loss': 0.0593, 'learning_rate': 1.9704889392544326e-05, 'epoch': 0.04}


                                                        
  2%|▏         | 3000/169428 [32:45<28:47:20,  1.61it/s]  

{'loss': 0.0582, 'learning_rate': 1.964586727105319e-05, 'epoch': 0.05}


                                                        
  2%|▏         | 3500/169428 [38:04<28:55:08,  1.59it/s]  

{'loss': 0.0646, 'learning_rate': 1.9586845149562058e-05, 'epoch': 0.06}


                                                        
  2%|▏         | 4000/169428 [43:20<28:19:54,  1.62it/s]  

{'loss': 0.0629, 'learning_rate': 1.9527823028070922e-05, 'epoch': 0.07}


                                                        
  3%|▎         | 4500/169428 [48:35<28:33:57,  1.60it/s]  

{'loss': 0.0636, 'learning_rate': 1.946880090657979e-05, 'epoch': 0.08}


                                                        
  3%|▎         | 5000/169428 [53:50<29:15:13,  1.56it/s]  

{'loss': 0.0614, 'learning_rate': 1.9409778785088654e-05, 'epoch': 0.09}


                                                        
  3%|▎         | 5500/169428 [59:25<30:02:48,  1.52it/s]  

{'loss': 0.0626, 'learning_rate': 1.9350756663597518e-05, 'epoch': 0.1}


                                                          
  4%|▎         | 6000/169428 [1:04:48<28:18:12,  1.60it/s]

{'loss': 0.0624, 'learning_rate': 1.9291734542106385e-05, 'epoch': 0.11}


                                                          
  4%|▍         | 6500/169428 [1:10:17<29:19:56,  1.54it/s]

{'loss': 0.0627, 'learning_rate': 1.9232712420615246e-05, 'epoch': 0.12}


                                                          
  4%|▍         | 7000/169428 [1:15:47<29:29:27,  1.53it/s]

{'loss': 0.0618, 'learning_rate': 1.9173690299124114e-05, 'epoch': 0.12}


                                                          
  4%|▍         | 7500/169428 [1:21:16<29:12:42,  1.54it/s]

{'loss': 0.0617, 'learning_rate': 1.9114668177632978e-05, 'epoch': 0.13}


  5%|▍         | 7989/169428 [1:26:45<29:57:03,  1.50it/s]

KeyboardInterrupt: 