In [1]:
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_cosine_with_hard_restarts_schedule_with_warmup
import os
from tqdm.auto import tqdm
import torch



In [2]:
class UnparallelDataset(Dataset):
    def __init__(self, path, data_dir='../data/classification/skoltech-jigsaw/'):
        super().__init__()
        
        data_path = os.path.join(data_dir, path)
        
        self.data_list = []
        self.eos = " <|endoftext|>"
        
        df = pd.read_csv(data_path, sep='\t', names=['labels', 'text'])
        df['labels'] = df.labels.apply(lambda x: 'toxic' if x == 1 else 'normal')
        
        for row in tqdm(df.iterrows(), desc=f'Reading {path}'):
            self.data_list.append(f'{row[1]["labels"]}: {row[1]["text"]}{self.eos}')
    
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, item):
        return self.data_list[item]
    

class ParaphraseDataset(Dataset):
    def __init__(self, path, data_dir='../data/paraphrase/'):
        super().__init__()
        
        data_path = os.path.join(data_dir, path)
        
        self.data_list = []
        self.eos = " <|endoftext|>"
        
        df = pd.read_csv(data_path, sep='\t')
        
        for row in tqdm(df.iterrows(), desc=f'Reading {path}'):
            self.data_list.append(f'paraphrase: {row[1]["source"]} >>> {row[1]["backtranslate"]}{self.eos}')
            
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, item):
        return self.data_list[item]

In [3]:
def get_data_loader(path, mode='unparallel'):
    if mode == 'unparallel':
        dataset = UnparallelDataset(path)
    elif mode == 'paraphrase':
        dataset = ParaphraseDataset(path)
    else:
        raise NotImplementedError('available mode: [unparallel, paraphrase]')
        
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    return loader

In [4]:
def train(epochs, loader, batch_size, tokenizer, model, device):
    
    batch_counter = 0
    sumloss = 0

    num_steps = epochs * len(loader)
    pb = tqdm(range(num_steps))
    
    for e in range(epochs):
        print(f'Epoch {e+1}')
        
        for step, txt in enumerate(loader):
            txt = torch.tensor(tokenizer.encode(txt[0]))
            txt = txt.unsqueeze(0).to(device)
            outputs = model(txt, labels=txt)
            loss, _ = outputs[:2]
            loss.backward()
            sumloss += loss.item()
            
            if step % batch_size == 0:
                batch_counter += 1
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model.zero_grad()
                
            if batch_counter == 10:
                print(f'Total Loss: {sumloss}')
                batch_counter = 0
                sumloss = 0
                
            pb.update(1)
                
    return model


def save_model(model, name):
    print('saving model...')
    torch.save(model.state_dict(), f'{name}.pt')
    
    
def load_model():
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    return tokenizer, model

In [5]:
BATCH_SIZE = 128
EPOCHS = 3
LEARNING_RATE = 3e-5
WARMUP_STEPS = 300
MAX_SEQ_LEN = 128
MODEL_PATH = '../model/unparallel.pt'
DATA_FILE = 'train.txt'  # 'paraphrase_ref.csv'


TOKENIZER, MODEL = load_model()
LOADER = get_data_loader(DATA_FILE)

DEVICE = 'cuda:4' if torch.cuda.is_available() else 'cpu'

model = MODEL.to(DEVICE)
model.train()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=-1)

Reading train.txt: 0it [00:00, ?it/s]



In [6]:
model = train(EPOCHS, LOADER, BATCH_SIZE, TOKENIZER, MODEL, DEVICE)

  0%|          | 0/300000 [00:00<?, ?it/s]

Epoch 1
Total Loss: 6556.524174451828
Total Loss: 7024.49076461792
Total Loss: 6517.577111721039
Total Loss: 5871.979538440704
Total Loss: 5385.574621915817
Total Loss: 5006.7227239608765
Total Loss: 4820.105655789375
Total Loss: 4736.162583589554
Total Loss: 4616.006004571915
Total Loss: 4508.450771689415
Total Loss: 4485.800826787949
Total Loss: 4511.129070997238
Total Loss: 4407.7837607860565
Total Loss: 4490.502493619919
Total Loss: 4411.379061102867
Total Loss: 4390.180621385574
Total Loss: 4375.2236849069595
Total Loss: 4408.710793435574
Total Loss: 4413.848687887192
Total Loss: 4375.677514910698
Total Loss: 4349.256544589996
Total Loss: 4385.362689256668
Total Loss: 4315.6337769031525
Total Loss: 4336.66351890564
Total Loss: 4273.727068543434
Total Loss: 4336.140874028206
Total Loss: 4304.023709774017
Total Loss: 4283.8170664310455
Total Loss: 4289.819076538086
Total Loss: 4274.579474329948
Total Loss: 4317.073293685913
Total Loss: 4331.883008003235
Total Loss: 4311.820141077042

In [8]:
save_model(model, MODEL_PATH)

saving model...


### Generation

In [9]:
def choose_from_top_k_top_n(probs, k=50, p=0.8):
    ind = np.argpartition(probs, -k)[-k:]
    top_prob = probs[ind]
    top_prob = {i: top_prob[idx] for idx,i in enumerate(ind)}
    sorted_top_prob = {k: v for k, v in sorted(top_prob.items(), key=lambda item: item[1], reverse=True)}
    
    t=0
    f=[]
    pr = []
    for k,v in sorted_top_prob.items():
        t+=v
        f.append(k)
        pr.append(v)
        if t>=p:
            break
    top_prob = pr / np.sum(pr)
    token_id = np.random.choice(f, 1, p = top_prob)

    return int(token_id)

def load_models(model_name):
    print ('Loading Trained GPT-2 Model')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
    model = GPT2LMHeadModel.from_pretrained('gpt2-medium')
    model_path = model_name
    model.load_state_dict(torch.load(model_path))
    return tokenizer, model

In [14]:
def generate(tokenizer, model, sentences, label):
    with torch.no_grad():
        for idx in range(sentences):
            finished = False
            cur_ids = torch.tensor(tokenizer.encode(label)).unsqueeze(0).to(DEVICE)
            for i in range(128):
                outputs = model(cur_ids, labels=cur_ids)
                loss, logits = outputs[:2]

                softmax_logits = torch.softmax(logits[0,-1], dim=0)

                if i < 5:
                    n = 10
                else:
                    n = 5

                next_token_id = choose_from_top_k_top_n(softmax_logits.to('cpu').numpy()) #top-k-top-n sampling
                cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(DEVICE) * next_token_id], dim = 1)

                if next_token_id in tokenizer.encode('<|endoftext|>'):
                    finished = True
                    break

            if finished:         
                output_list = list(cur_ids.squeeze().to('cpu').numpy())
                output_text = tokenizer.decode(output_list)
                print (output_text)
            else:
                output_list = list(cur_ids.squeeze().to('cpu').numpy())
                output_text = tokenizer.decode(output_list)
                print (output_text)

In [15]:
generate(TOKENIZER, model, 10, 'normal')

normal: your comments on trudeau are irrelevant. <|endoftext|>
normal: a good friend and friend of mine has a good job. <|endoftext|>
normal: he would take the initiative to give his party a good old fashioned look in the light of a progressive agenda. <|endoftext|>
normal: we had the war and it was better than the slavery. <|endoftext|>
normal: he will continue to pay the tax they just made. <|endoftext|>
normal: maybe in the future i am, but for now i want to learn, and work on, myself. <|endoftext|>
normal: its not a coincidence that many countries do not want a muslim invasion. <|endoftext|>
normal: to show how wrong i am with my own writing here. <|endoftext|>
normal: thats why the trudeau is a trudeau. <|endoftext|>
normal: the majority of the time she can be fairly described as a nazi sympathizer. <|endoftext|>


In [16]:
generate(TOKENIZER, model, 10, 'toxic')

toxic: he is just another corrupt politician. <|endoftext|>
toxic: that was a joke, and this is all, he is a fool. <|endoftext|>
toxic: im not even sure how you feel about that. <|endoftext|>
toxic: the americans need to be sent back to a land of no freedom of speech. <|endoftext|>
toxic: you are stupid. <|endoftext|>
toxic: when it comes to stupidity, that last bit matters most. <|endoftext|>
toxic: your stupidity and lack of common sense is the real problem with america. <|endoftext|>
toxic: i dont like this post, however, it seems alluding to the fact that, yes, i do love black people, it just seems so stupid. <|endoftext|>
toxic: you think theyll shoot you. <|endoftext|>
toxic: you are a big fucking fucker who has a bad attitude. <|endoftext|>


In [19]:
generate(TOKENIZER, model, 1, 'normal')

normal: and if theres no evidence to support it, then it is silly to argue. <|endoftext|>
