In [None]:
BASE_DIR = ''  # Working dir; prefix for every path below (give it a trailing '/' when non-empty)
DATA_DIR = f'{BASE_DIR}data/'
MODELS_DIR = f'{BASE_DIR}models/'
# DATA_DIR already ends with '/', so no extra separator is inserted here.
EXPACE_DIR = f'{DATA_DIR}expace-v1/'

# Load models

In [None]:
import torch
import json

In [None]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
device = 'cuda'
model_id = 'gpt2'

# Tokenizer and LM-head model for the same pretrained checkpoint; the model's
# weights are then moved onto `device`.
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id)
model = model.to(device)

In [None]:
# Reuse the end-of-sequence token as the padding token: the
# padding='max_length' call in LazyLoadFileLines.tokenize needs a pad token,
# and GPT-2's tokenizer does not define one by default.
tokenizer.pad_token = tokenizer.eos_token

# Load files

In [None]:
from os import listdir
import string

def read_file(file):
    """Return the entire contents of the UTF-8 text file at path `file` as a str."""
    with open(file, encoding="utf-8") as handle:
        contents = handle.read()
    return contents

def count_sentences(files, min_sent_len=0):
    """Count the sentences in `files` that are longer than `min_sent_len` characters.

    Args:
        files: iterable of paths to UTF-8 text files.
        min_sent_len: only sentences with strictly more characters than this
            are counted. This replaces the constant `MIN_SENT_LEN`, which is not
            defined anywhere in the notebook (the old code raised NameError).

    Returns:
        The total number of counted sentences over all files (int).
    """
    # Imported here because the notebook's `from nltk import sent_tokenize`
    # only runs in a later cell.
    from nltk import sent_tokenize

    total = 0
    for path in files:
        total += sum(1 for sent in sent_tokenize(read_file(path)) if len(sent) > min_sent_len)
    return total

# Corpus files: every name in EXPACE_DIR ending in '.txt' (a substring test would
# also pick up e.g. 'x.txt.bak'). Sorted so the list does not depend on the
# filesystem's listing order.
expace_files = [EXPACE_DIR + name for name in sorted(listdir(EXPACE_DIR)) if name.endswith('.txt')]

In [None]:
# Training hyper-parameters; the same dict is also sent to wandb as the run config.
HYPER_PARAMS = dict(
    epochs=3,
    lr=3e-5,
    batch_size=8,
    name='gpt2-expace-v1',
)
# Passed to pl.Trainer(val_check_interval=...): validate every 30% of a training epoch.
VAL_CHECK_INTERVAL = 0.3

# Prepare dataset

In [None]:
import torch
from nltk import sent_tokenize
from random import shuffle
import re
import string

REMOVE_CIT_MATH = False   # if True, drop sentences containing '[' or ']' (presumably citations / math)
REMOVE_NON_ASCII = False  # if True, drop sentences for which valid_sentence() is False
MIN_WORDS = 10            # drop sentences with fewer whitespace-separated words than this
USE_STRICT_PUNCT = False  # if True, only STRICT_PUNCT counts as allowed punctuation
STRICT_PUNCT = '*+,-./%\'"!~'

def valid_sentence(sentence):
    """Return True if every character of `sentence` is allowed.

    Allowed characters are ASCII digits, ASCII letters, `string.whitespace`
    and punctuation: STRICT_PUNCT when USE_STRICT_PUNCT is set, otherwise all
    of `string.punctuation`. An empty sentence is valid.

    The previous version interpolated these characters unescaped into a regex
    character class. With `string.punctuation` ('[', '\\', ']', '^', '-') that
    only worked because of the exact character order, and editing STRICT_PUNCT
    (e.g. adding '\\' or ']') would silently corrupt the class. Plain set
    membership has no such pitfalls and gives the same result for the current
    settings.
    """
    punctuation = STRICT_PUNCT if USE_STRICT_PUNCT else string.punctuation
    allowed = set(string.digits + string.ascii_letters + punctuation + string.whitespace)
    return all(ch in allowed for ch in sentence)

def load_sentences(datasets, files, max_sentences='all'):
    """Append filtered sentences from every file in `files` to `datasets` (in place).

    Each file is read, split with nltk's `sent_tokenize` and shuffled. A
    sentence is skipped when any of these holds (checked in this order):
      * REMOVE_NON_ASCII is set and valid_sentence() rejects it;
      * it has fewer than MIN_WORDS whitespace-separated words;
      * REMOVE_CIT_MATH is set and it contains '[' or ']';
      * it contains a newline.
    At most `max_sentences` sentences are kept per file; 'all' (or any falsy
    value) means no limit. Returns None.
    """
    limited = max_sentences != 'all' and bool(max_sentences)
    for path in files:
        candidates = sent_tokenize(read_file(path))
        shuffle(candidates)

        kept = 0
        for sentence in candidates:
            rejected = (
                (REMOVE_NON_ASCII and not valid_sentence(sentence))
                or len(sentence.split()) < MIN_WORDS
                or (REMOVE_CIT_MATH and ('[' in sentence or ']' in sentence))
                or '\n' in sentence
            )
            if rejected:
                continue
            datasets.append(sentence)
            kept += 1
            if limited and kept >= max_sentences:
                break
                
def save_dataset(dataset, name):
    """Write `dataset` to the file `name` (UTF-8, overwriting), one item per line.

    Each item is formatted with an f-string and followed by a newline.
    """
    with open(name, 'w', encoding='UTF-8') as out:
        out.writelines(f'{entry}\n' for entry in dataset)

def build_preloaded_files(files, num_files, num_sents_per_file, output_folder, version_name, seed=None):
    """Sample sentences from `files`, split them 80/10/10 and write each split to disk.

    Args:
        files: paths of the source text files. The caller's list is not modified.
        num_files: how many randomly chosen files to read, or 'all'.
        num_sents_per_file: per-file cap forwarded to `load_sentences`, or 'all'.
        output_folder: directory for the split files; created if missing.
        version_name: suffix used in the output file names.
        seed: optional seed for the `random` module, making the file choice,
            sentence sampling and split reproducible. None (default) leaves the
            global RNG untouched.

    Writes `train-`, `val-` and `test-{num_files}-{num_sents_per_file}-{version_name}.txt`
    into `output_folder`.

    Returns:
        (train, val, test) lists of sentences.
    """
    import random
    from os import makedirs
    from os.path import join

    if seed is not None:
        # `shuffle` here and in load_sentences is the module-level random.shuffle.
        random.seed(seed)

    # Shuffle a copy so the caller's list (e.g. expace_files) is left as it was.
    files = list(files)
    shuffle(files)
    selected = files if num_files == 'all' else files[:num_files]

    print('Loading sentences')
    datasets = []
    load_sentences(datasets, selected, num_sents_per_file)
    shuffle(datasets)
    print('Total sentences', len(datasets))

    # 80% train, 10% validation, the remainder (~10%) test.
    train_size = int(0.8 * len(datasets))
    val_limit = train_size + int(0.1 * len(datasets))
    train_data_with_label = datasets[:train_size]
    val_data_with_label = datasets[train_size:val_limit]
    test_data_with_label = datasets[val_limit:]

    if output_folder:
        makedirs(output_folder, exist_ok=True)
    suffix = f'{num_files}-{num_sents_per_file}-{version_name}.txt'
    save_dataset(train_data_with_label, join(output_folder, f'train-{suffix}'))
    save_dataset(val_data_with_label, join(output_folder, f'val-{suffix}'))
    save_dataset(test_data_with_label, join(output_folder, f'test-{suffix}'))
    return train_data_with_label, val_data_with_label, test_data_with_label

In [None]:
NUM_FILES, SENTENCES_PER_FILE, DATA_VERSION = 'all', 'all', 'v1'

# Build the train/val/test sentence files under {DATA_DIR}expace-sentences.
train_data_with_label, val_data_with_label, test_data_with_label = build_preloaded_files(
    files=expace_files,
    num_files=NUM_FILES,
    num_sents_per_file=SENTENCES_PER_FILE,
    output_folder=f'{DATA_DIR}expace-sentences',
    version_name=DATA_VERSION,
)

In [None]:
from torch.utils.data import DataLoader, IterableDataset
from nltk import sent_tokenize

class LazyLoadFileLines(IterableDataset):
    """Iterable dataset over a UTF-8 text file holding one sentence per line.

    Despite the name, the whole file is read into memory in `__init__`; only
    tokenization happens lazily, one sentence at a time while iterating.

    Lines are split on newline characters only, i.e. the separator that
    `save_dataset` writes. `str.splitlines()` (used previously) also splits on
    form feeds, vertical tabs, U+2028 and similar characters, which can occur
    inside a sentence and would break it into fragments.

    Args:
        file: path of the text file.
        device: device the token tensors are created on (default 'cuda').
        max_length: every sentence is truncated / padded to this many tokens.
    """

    def __init__(self, file, device='cuda', max_length=150):
        with open(file, encoding="utf-8") as f:
            text = f.read().rstrip()
        # An empty or whitespace-only file yields no sentences.
        lines = text.split('\n') if text else []
        self.sentences = lines
        self.length = len(lines)
        self.device = device
        self.max_length = max_length

    def __len__(self):
        return self.length

    def get_next_sentence(self):
        """Yield the tokenized form of each sentence, in file order."""
        for sent_index, sent in enumerate(self.sentences):
            yield self.tokenize(sent, sent_index)

    def tokenize(self, sentence, sent_index):
        """Tokenize one sentence with the notebook-level `tokenizer`.

        Returns a dict mapping each tokenizer output key (e.g. 'input_ids',
        'attention_mask') to a 1-D tensor of length `self.max_length` on
        `self.device`. `sent_index` is unused; kept for interface compatibility.
        """
        encodings = tokenizer([sentence], truncation=True, padding='max_length', max_length=self.max_length)
        # The tokenizer was given a one-element batch; unwrap it with val[0].
        return {key: torch.tensor(val[0], device=self.device) for key, val in encodings.items()}

    def collate(self, batch):
        """Identity collate function: return `batch` unchanged.

        Not passed to the DataLoaders in this notebook.
        """
        return batch

    def __iter__(self):
        return self.get_next_sentence()
    
def get_iteratable_dataset_lines(num_files, num_sents_per_file, base_folder, version_name):
    """Open the train/val/test sentence files named `{split}-{num_files}-{num_sents_per_file}-{version_name}.txt`.

    Args:
        num_files, num_sents_per_file, version_name: the parts of the file names.
        base_folder: directory holding the files; works with or without a
            trailing separator. Previously the train path needed one and the
            val/test paths inserted an extra '/'.

    Returns:
        (train_dataset, val_dataset, test_dataset, file_name): three
        LazyLoadFileLines instances and the bare name of the training file.
    """
    from os.path import join  # local import: `os` is only imported in a later cell

    suffix = f'{num_files}-{num_sents_per_file}-{version_name}.txt'
    train_dataset, val_dataset, test_dataset = (
        LazyLoadFileLines(join(base_folder, f'{split}-{suffix}'))
        for split in ('train', 'val', 'test')
    )

    file_name = f'train-{suffix}'
    return train_dataset, val_dataset, test_dataset, file_name

NUM_FILES = 'all'
SENTENCES_PER_FILE = 'all'
DATA_VERSION = 'v1'

# Must point at the folder build_preloaded_files wrote the splits to.
# (The argument previously read `{DATA)DIR}`, a SyntaxError that stopped this whole cell.)
train_dataset, val_dataset, test_dataset, file_name = get_iteratable_dataset_lines(
    NUM_FILES, SENTENCES_PER_FILE, f'{DATA_DIR}expace-sentences/', DATA_VERSION
)

# Record which training file was used; HYPER_PARAMS is later sent to wandb as the config.
HYPER_PARAMS['file'] = file_name

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = HYPER_PARAMS['batch_size']

# One loader per split, all with the same batch size and default settings;
# the datasets are iterable-style, so no sampler is passed.
train_dataloader, validation_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Define trainer

In [None]:
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.metrics import functional as FM
from transformers import AdamW

class LMHeadModel(pl.LightningModule):
    """Lightning wrapper for fine-tuning a Hugging Face causal language model.

    Args:
        model: the LM to train (here a `GPT2LMHeadModel`). It is called with
            `input_ids`, `attention_mask` and `labels` and must return an object
            exposing `.loss` and `.logits`.
        tokenizer: tokenizer matching `model`; used only by `forward`.
    """

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        # NOTE(review): not referenced by any step below.
        self.valid_acc = pl.metrics.Accuracy()

    @staticmethod
    def _lm_labels(batch):
        """Return LM targets: `input_ids` with padded positions (attention_mask == 0) set to -100.

        Hugging Face's LM loss ignores the label -100. The pad token in this
        notebook is the EOS token, so without masking, the loss rewarded the
        model for emitting long runs of EOS after every sentence.
        """
        return batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)

    def forward(self, x):
        """Return the next-token distribution after the text `x`.

        The result is a dict mapping each vocabulary token (as given by
        `tokenizer.convert_ids_to_tokens`) to the probability that it follows `x`.
        Side effect: puts the wrapped model in eval mode.
        """
        self.model.eval()
        # Send the input to whichever device holds the model's weights.
        model_device = next(self.model.parameters()).device
        input_ids = self.tokenizer.encode(x, return_tensors='pt').to(model_device)
        with torch.no_grad():
            logits = self.model(input_ids).logits  # (1, seq_len, vocab_size) for GPT2LMHeadModel
        # Normalise over the vocabulary (last axis) at the final position.
        probs = F.softmax(logits[0, -1], dim=-1).cpu().tolist()
        tokens = self.tokenizer.convert_ids_to_tokens(list(range(len(probs))))
        return dict(zip(tokens, probs))

    def training_step(self, batch, batch_idx):
        """Compute and log the LM loss on one training batch; return it for backprop."""
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=self._lm_labels(batch),
        )
        self.log('train_loss', outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the LM loss on one validation batch."""
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=self._lm_labels(batch),
        )
        self.log('validation_loss', outputs.loss)

    def configure_optimizers(self):
        """AdamW over all model parameters with the learning rate from HYPER_PARAMS."""
        optimizer = AdamW(self.model.parameters(), lr = HYPER_PARAMS['lr'], eps = 1e-8)
        return optimizer

In [None]:
import wandb
import os
from pytorch_lightning.loggers import WandbLogger 

# Authenticate with Weights & Biases, then create the Lightning logger that
# is passed to pl.Trainer(logger=...) in the training cell.
wandb.login()
wandb_logger = WandbLogger()

In [None]:
# Start the W&B run with HYPER_PARAMS (including the training file name) as its
# config; the trailing ';' keeps the cell from displaying the returned run.
wandb.init(project='GPT2 - Expace', config=HYPER_PARAMS);

In [None]:
# Header line followed by the hyper-parameters as JSON (indent=True == 1 space).
print('Training model with:', json.dumps(HYPER_PARAMS, indent=True), sep='\n')

In [None]:
gpt2_expace = LMHeadModel(model, tokenizer)

# Single-GPU run, no checkpoints, metrics reported to wandb; validation runs
# every VAL_CHECK_INTERVAL of a training epoch.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=HYPER_PARAMS['epochs'],
    val_check_interval=VAL_CHECK_INTERVAL,
    checkpoint_callback=False,
    logger=wandb_logger,
)
trainer.fit(gpt2_expace, train_dataloader, validation_dataloader)

# gpt2_expace.model is the same object as `model`, now fine-tuned.
gpt2_expace.model.save_pretrained(f'{MODELS_DIR}{HYPER_PARAMS["name"]}')

In [None]:
# Save the fine-tuned weights together with the tokenizer (its pad token was
# changed to EOS earlier), so the directory can be reloaded on its own with
# from_pretrained. The folder name comes from HYPER_PARAMS rather than a
# hard-coded copy, so it always matches the run name.
save_dir = f'{MODELS_DIR}{HYPER_PARAMS["name"]}'
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)