In [1]:
from IPython.display import clear_output

!pip install transformers
!pip install datasets
!pip install torchtext
!pip3 install tensorflow_text
!pip3 install urllib3==1.25.4


clear_output()

Insert here link to the data set or unzip the archive in the same directory with this colab notebook.

In [2]:
# !wget <> && unzip GYAFC_Corpus.zip
# clear_output()

Import libs. 

In [3]:
import os
import re
import torch
import pprint
import torch.nn as nn
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import AdamW, get_linear_schedule_with_warmup
from pathlib import Path

Create datasets.

In [14]:
train_formal, train_informal = [], []
valid_formal, valid_informal = [], []
test_formal, test_informal = [], []


with open('GYAFC_Corpus/Entertainment_Music/train/formal', 'r', encoding='utf-8') as file:
    train_formal = file.readlines()
with open('GYAFC_Corpus/Entertainment_Music/train/informal', 'r', encoding='utf-8') as file:
    train_informal = file.readlines()
    
with open('GYAFC_Corpus/Entertainment_Music/tune/formal.ref0', 'r', encoding='utf-8') as file:
    valid_formal = file.readlines()
with open('GYAFC_Corpus/Entertainment_Music/tune/informal', 'r', encoding='utf-8') as file:
    valid_informal = file.readlines()
    
with open('GYAFC_Corpus/Entertainment_Music/test/formal.ref0', 'r', encoding='utf-8') as file:
    test_formal = file.readlines()
with open('GYAFC_Corpus/Entertainment_Music/test/informal', 'r', encoding='utf-8') as file:
    test_informal = file.readlines()

train_formal = [re.sub('\n', '', x) for x in train_formal]
train_informal = [re.sub('\n', '', x) for x in train_informal]
valid_formal = [re.sub('\n', '', x) for x in valid_formal]
valid_informal = [re.sub('\n', '', x) for x in valid_informal]
test_formal = [re.sub('\n', '', x) for x in test_formal]
test_informal = [re.sub('\n', '', x) for x in test_informal]

delimiter = ' >>> '
train = [x + delimiter + y for x, y in zip(train_informal, train_formal)]
valid = [x + delimiter + y for x, y in zip(valid_informal, valid_formal)]
test = [x + delimiter + y for x, y in zip(test_informal, test_formal)]

In [15]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, texts):
        self.tokenizer = tokenizer
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        source_text = self.texts[index].split(' >>> ')[0]
        target_text = self.texts[index].split(' >>> ')[-1]
        source = self.tokenizer.encode_plus(
            source_text,
            max_length=100,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors='pt'
            )
        target = self.tokenizer.encode_plus(
            target_text,
            max_length=100,
            pad_to_max_length=True,
            truncation=True,
            padding="max_length",
            return_tensors='pt'
            )

        source_ids = source['input_ids'].squeeze()
        source_mask = source['attention_mask'].squeeze()
        target_ids = target['input_ids'].squeeze()
        target_mask = target['attention_mask'].squeeze()

        return {
            'source_ids': source_ids.to(dtype=torch.long), 
            'source_mask': source_mask.to(dtype=torch.long), 
            'target_ids': target_ids.to(dtype=torch.long),
            'target_mask': target_mask.to(dtype=torch.long)
        }

In [22]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

trainset = MyDataset(tokenizer=tokenizer, texts=train)
testset = MyDataset(tokenizer=tokenizer, texts=test)
valset = MyDataset(tokenizer=tokenizer, texts=valid)

In [23]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

batch_size = 16
train_dataloader = DataLoader(trainset, sampler=SequentialSampler(trainset), batch_size=batch_size, num_workers=2)
validation_dataloader = DataLoader(valset, sampler=SequentialSampler(valset), batch_size=batch_size, num_workers=2)

In [24]:
from transformers import GPT2Config
import random 

configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained("gpt2", config=configuration)
model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda")

seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

If you are not willing to train the model, skip the cells below

In [26]:
epochs = 3
learning_rate = 5e-6
warmup_steps = 5e2
epsilon = 1e-10

sample_every = 100

optimizer = AdamW(model.parameters(), lr=learning_rate, eps=epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=len(train_dataloader)*epochs)

In [27]:
from tqdm import tqdm

training_stats = []
model = model.to(device)

for epoch_i in tqdm(range(epochs)):

    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))

    total_train_loss = 0
    model.train()

    for batch in tqdm(train_dataloader):

        b_input_ids = batch['source_ids'].to(device)
        b_labels = batch['target_ids'].to(device)
        b_masks = batch['source_mask'].to(device)

        model.zero_grad()      
        optimizer.zero_grad()  
        outputs = model(b_input_ids, labels=b_labels, attention_mask=b_masks, token_type_ids=None)
        loss = outputs.loss  

        batch_loss = loss.item()
        total_train_loss += batch_loss

        loss.backward()
        optimizer.step()
        scheduler.step()

    avg_train_loss = total_train_loss/len(train_dataloader)       

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("Running Validation...")

    model.eval()

    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in tqdm(validation_dataloader):
        
        b_input_ids = batch['source_ids'].to(device)
        b_labels = batch['target_ids'].to(device)
        b_masks = batch['source_mask'].to(device)
        
        with torch.no_grad():        
            
            outputs=model(b_input_ids, attention_mask=b_masks, labels=b_labels)
            loss=outputs[0]  
            
        batch_loss=loss.item()
        total_eval_loss+=batch_loss        

    avg_val_loss=total_eval_loss / len(validation_dataloader)
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
        }
    )


  0%|          | 0/3 [00:00<?, ?it/s][A

  0%|          | 0/3288 [00:00<?, ?it/s][A[A



[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
 24%|██▍       | 790/3288 [06:21<20:26,  2.04it/s][A[A

 24%|██▍       | 791/3288 [06:21<20:25,  2.04it/s][A[A

 24%|██▍       | 792/3288 [06:22<20:24,  2.04it/s][A[A

 24%|██▍       | 793/3288 [06:22<20:22,  2.04it/s][A[A

 24%|██▍       | 794/3288 [06:23<20:22,  2.04it/s][A[A

 24%|██▍       | 795/3288 [06:23<20:21,  2.04it/s][A[A

 24%|██▍       | 796/3288 [06:23<20:21,  2.04it/s][A[A

 24%|██▍       | 797/3288 [06:24<20:20,  2.04it/s][A[A

 24%|██▍       | 798/3288 [06:24<20:20,  2.04it/s][A[A

 24%|██▍       | 799/3288 [06:25<20:20,  2.04it/s][A[A

 24%|██▍       | 800/3288 [06:25<20:18,  2.04it/s][A[A

 24%|██▍       | 801/3288 [06:26<20:18,  2.04it/s][A[A

 24%|██▍       | 802/3288 [06:26<20:17,  2.04it/s][A[A

 24%|██▍       | 803/3288 [06:27<20:17,  2.04it/s][A[A

 24%|██▍       | 804/3288 [06:27<20:17,  2.04it/s][A[A

 24%|██▍       | 805/3288 [06:28<20:16,  2.04it/


  Average training loss: 1.58
Running Validation...




  1%|          | 1/180 [00:00<00:49,  3.58it/s][A[A

  1%|          | 2/180 [00:00<00:43,  4.11it/s][A[A

  2%|▏         | 3/180 [00:00<00:38,  4.60it/s][A[A

  2%|▏         | 4/180 [00:00<00:35,  5.02it/s][A[A

  3%|▎         | 5/180 [00:00<00:32,  5.35it/s][A[A

  3%|▎         | 6/180 [00:01<00:32,  5.36it/s][A[A

  4%|▍         | 7/180 [00:01<00:30,  5.66it/s][A[A

  4%|▍         | 8/180 [00:01<00:29,  5.85it/s][A[A

  5%|▌         | 9/180 [00:01<00:28,  6.00it/s][A[A

  6%|▌         | 10/180 [00:01<00:28,  6.04it/s][A[A

  6%|▌         | 11/180 [00:01<00:27,  6.18it/s][A[A

  7%|▋         | 12/180 [00:02<00:27,  6.03it/s][A[A

  7%|▋         | 13/180 [00:02<00:27,  6.08it/s][A[A

  8%|▊         | 14/180 [00:02<00:27,  6.12it/s][A[A

  8%|▊         | 15/180 [00:02<00:26,  6.14it/s][A[A

  9%|▉         | 16/180 [00:02<00:26,  6.20it/s][A[A

  9%|▉         | 17/180 [00:02<00:26,  6.22it/s][A[A

 10%|█         | 18/180 [00:03<00:26,  6.15it/s][A[A


  Validation Loss: 0.80


[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
 24%|██▍       | 790/3288 [06:22<20:29,  2.03it/s][A[A

 24%|██▍       | 791/3288 [06:23<20:27,  2.03it/s][A[A

 24%|██▍       | 792/3288 [06:23<20:30,  2.03it/s][A[A

 24%|██▍       | 793/3288 [06:24<20:33,  2.02it/s][A[A

 24%|██▍       | 794/3288 [06:24<20:29,  2.03it/s][A[A

 24%|██▍       | 795/3288 [06:25<20:28,  2.03it/s][A[A

 24%|██▍       | 796/3288 [06:25<20:27,  2.03it/s][A[A

 24%|██▍       | 797/3288 [06:26<20:27,  2.03it/s][A[A

 24%|██▍       | 798/3288 [06:26<20:25,  2.03it/s][A[A

 24%|██▍       | 799/3288 [06:27<20:26,  2.03it/s][A[A

 24%|██▍       | 800/3288 [06:27<20:28,  2.03it/s][A[A

 24%|██▍       | 801/3288 [06:28<20:26,  2.03it/s][A[A

 24%|██▍       | 802/3288 [06:28<20:26,  2.03it/s][A[A

 24%|██▍       | 803/3288 [06:29<20:26,  2.03it/s][A[A

 24%|██▍       | 804/3288 [06:29<20:26,  2.03it/s][A[A

 24%|██▍       | 805/3288 [06:30<20:24,  2.03it/


  Average training loss: 0.84
Running Validation...




  1%|          | 1/180 [00:00<00:52,  3.39it/s][A[A

  1%|          | 2/180 [00:00<00:45,  3.91it/s][A[A

  2%|▏         | 3/180 [00:00<00:39,  4.44it/s][A[A

  2%|▏         | 4/180 [00:00<00:35,  4.90it/s][A[A

  3%|▎         | 5/180 [00:00<00:33,  5.24it/s][A[A

  3%|▎         | 6/180 [00:01<00:33,  5.21it/s][A[A

  4%|▍         | 7/180 [00:01<00:31,  5.44it/s][A[A

  4%|▍         | 8/180 [00:01<00:30,  5.59it/s][A[A

  5%|▌         | 9/180 [00:01<00:29,  5.82it/s][A[A

  6%|▌         | 10/180 [00:01<00:28,  5.88it/s][A[A

  6%|▌         | 11/180 [00:01<00:28,  5.99it/s][A[A

  7%|▋         | 12/180 [00:02<00:28,  5.90it/s][A[A

  7%|▋         | 13/180 [00:02<00:28,  5.86it/s][A[A

  8%|▊         | 14/180 [00:02<00:28,  5.90it/s][A[A

  8%|▊         | 15/180 [00:02<00:27,  5.98it/s][A[A

  9%|▉         | 16/180 [00:02<00:26,  6.08it/s][A[A

  9%|▉         | 17/180 [00:02<00:26,  6.13it/s][A[A

 10%|█         | 18/180 [00:03<00:26,  6.01it/s][A[A


  Validation Loss: 0.77


[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
 24%|██▍       | 790/3288 [06:28<20:24,  2.04it/s][A[A

 24%|██▍       | 791/3288 [06:28<20:29,  2.03it/s][A[A

 24%|██▍       | 792/3288 [06:29<20:30,  2.03it/s][A[A

 24%|██▍       | 793/3288 [06:29<20:24,  2.04it/s][A[A

 24%|██▍       | 794/3288 [06:30<20:24,  2.04it/s][A[A

 24%|██▍       | 795/3288 [06:30<20:25,  2.03it/s][A[A

 24%|██▍       | 796/3288 [06:31<20:25,  2.03it/s][A[A

 24%|██▍       | 797/3288 [06:31<20:22,  2.04it/s][A[A

 24%|██▍       | 798/3288 [06:32<20:24,  2.03it/s][A[A

 24%|██▍       | 799/3288 [06:32<20:24,  2.03it/s][A[A

 24%|██▍       | 800/3288 [06:33<20:21,  2.04it/s][A[A

 24%|██▍       | 801/3288 [06:33<20:20,  2.04it/s][A[A

 24%|██▍       | 802/3288 [06:34<20:20,  2.04it/s][A[A

 24%|██▍       | 803/3288 [06:34<20:23,  2.03it/s][A[A

 24%|██▍       | 804/3288 [06:35<20:25,  2.03it/s][A[A

 24%|██▍       | 805/3288 [06:35<20:21,  2.03it/


  Average training loss: 0.82
Running Validation...




  1%|          | 1/180 [00:00<00:49,  3.60it/s][A[A

  1%|          | 2/180 [00:00<00:43,  4.10it/s][A[A

  2%|▏         | 3/180 [00:00<00:38,  4.60it/s][A[A

  2%|▏         | 4/180 [00:00<00:35,  5.00it/s][A[A

  3%|▎         | 5/180 [00:00<00:33,  5.27it/s][A[A

  3%|▎         | 6/180 [00:01<00:32,  5.36it/s][A[A

  4%|▍         | 7/180 [00:01<00:30,  5.62it/s][A[A

  4%|▍         | 8/180 [00:01<00:29,  5.82it/s][A[A

  5%|▌         | 9/180 [00:01<00:29,  5.87it/s][A[A

  6%|▌         | 10/180 [00:01<00:28,  5.93it/s][A[A

  6%|▌         | 11/180 [00:01<00:27,  6.05it/s][A[A

  7%|▋         | 12/180 [00:02<00:28,  5.98it/s][A[A

  7%|▋         | 13/180 [00:02<00:27,  6.00it/s][A[A

  8%|▊         | 14/180 [00:02<00:27,  6.03it/s][A[A

  8%|▊         | 15/180 [00:02<00:27,  6.03it/s][A[A

  9%|▉         | 16/180 [00:02<00:27,  6.07it/s][A[A

  9%|▉         | 17/180 [00:02<00:26,  6.12it/s][A[A

 10%|█         | 18/180 [00:03<00:26,  6.10it/s][A[A


  Validation Loss: 0.76





In [28]:
training_stats

[{'Training Loss': 1.5802327106070055,
  'Valid. Loss': 0.8018712182839711,
  'epoch': 1},
 {'Training Loss': 0.8437127087473725,
  'Valid. Loss': 0.7708918824791908,
  'epoch': 2},
 {'Training Loss': 0.8197149884617386,
  'Valid. Loss': 0.7624093878600332,
  'epoch': 3}]

Create results txt file. Run from here if you want only to use pretrained model to calculate results. 

In [44]:
#download the model
# !wget https://www.dropbox.com/s/rdy9j2mjthdx16s/gpt_music.zip  && unzip gpt_music.zip

In [45]:
configuration = GPT2Config.from_pretrained('gpt_music', output_hidden_states=False)
model = GPT2LMHeadModel.from_pretrained('gpt_music', config=configuration)
model.resize_token_embeddings(len(tokenizer))

tokenizer = GPT2Tokenizer.from_pretrained('gpt_music')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

0

In [34]:
import tqdm 
from IPython.display import clear_output

model.eval().cuda()

res = []
for sentence in tqdm.tqdm(test):
    gen = model.generate(
        input_ids=tokenizer.encode(sentence.split(' >>> ')[0], return_tensors='pt').cuda(),
        do_sample=True,
        top_k=10,
        top_p=0.1,
        temperature=10.,
        num_return_sequences=1,
        repetition_penalty=10.,
        max_length=len(sentence.split(' >>> ')[-1]) + 10
        )
    
    decoded = tokenizer.decode(gen[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
    res.append(decoded.split(' >>> ')[-1])
    clear_output()


100%|██████████| 1416/1416 [14:38<00:00,  1.61it/s]


In [36]:
with open('results_gpt_music.txt', 'w') as file:
    file.write('\n'.join(x for x in res))

Inference.


In [35]:
#!wget https://www.dropbox.com/s/o4a53s6p8rwyj0j/clf_family.zip && unzip clf_family.zip
!wget https://www.dropbox.com/s/7xopvppe3yyh8ft/clf_music.zip && unzip clf_music.zip
!git clone https://github.com/maxs-kan/text_style_transfer.git
!pip install bert_embedding
!pip install numpy --upgrade
clear_output()

In [37]:
!python ./text_style_transfer/metric/evaluate.py -i ./GYAFC_Corpus/Entertainment_Music/test/informal -p ./results_gpt_music.txt --tox_classifier_path ./clf_music_lower/

2021-05-26 14:49:21.051165: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
Vocab file is not found. Downloading.
Downloading /root/.mxnet/models/book_corpus_wiki_en_uncased-a6607397.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/book_corpus_wiki_en_uncased-a6607397.zip...
Downloading /root/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip...
1416
1416
--------------------
Calculating Style Transfer Accuracy
100% 89/89 [00:45<00:00,  1.96it/s]
Calculating BLEU similarity
Calculating EMB similarity
100% 45/45 [02:52<00:00,  3.84s/it]
100% 45/45 [02:51<00:00,  3.82s/it]
Calculating token-level perplexity
Downloading: 100% 718/718 [00:00<00:00, 660kB/s]
Downloading: 100% 1.52G/1.52G [00:26<00:00, 58.1MB/s]
Downloading: 100% 1.04M

In [38]:
import codecs
import pandas as pd

with codecs.open('./results.md', 'r', encoding='utf-8') as inp:
    res = [s.strip('\n') for s in inp.readlines()]

m_res = res[2].split('|')
d = {'STA': m_res[1], 'CS':m_res[2], 'BLEU':m_res[3], 'PPL':m_res[4], 'GM':m_res[5]}
pd.DataFrame(data=d, index=[0])

Unnamed: 0,STA,CS,BLEU,PPL,GM
0,0.21,0.96,0.92,187.77,0.1
