In [None]:
!git clone https://github.com/muhwagua/color-bert.git

In [2]:
!pip install transformers

Successfully installed sacremoses-0.0.43 tokenizers-0.10.1 transformers-4.4.2


In [13]:
import random
import re
import urllib.request

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    DataCollatorForLanguageModeling
)
from argparse import Namespace

In [4]:
txt_url = "https://raw.githubusercontent.com/muhwagua/color-bert/main/data/all.txt"
urllib.request.urlretrieve(txt_url, 'train.txt')

('train.txt', <http.client.HTTPMessage at 0x7fb3452fbe50>)

In [5]:
args = Namespace()
args.train = "train.txt"
args.max_len = 128
args.model_name = "bert-base-uncased"
args.batch_size = 4
args.color_ratio = 0.5

In [6]:
tokenizer = BertTokenizer.from_pretrained(args.model_name)


class MaskedLMDataset(Dataset):
    def __init__(self, file, color_ratio, tokenizer, masking):
        self.tokenizer = tokenizer
        self.color_ratio = color_ratio
        self.masking = masking
        self.lines = self.load_lines(file)
        self.masked = self.all_mask(self.lines, self.color_ratio)
        self.ids = self.encode_lines(self.lines, self.masked, masking)

    def load_lines(self, file):
        with open(file) as f:
            lines = [
                line
                for line in f.read().splitlines()
                if (len(line) > 0 and not line.isspace())
            ]
        return lines

    def color_mask(self, line, masking=True):
        colors = [
        "red",
        "orange",
        "yellow",
        "green",
        "blue",
        "purple",
        "brown",
        "white",
        "black",
        "pink",
        "lime",
        "gray",
        "violet",
        "cyan",
        "magenta",
        "khaki",
    ]
        for color in colors:
            match = re.search(f"(\s|^){color}(\s|[.!?\\-])", line)
            if match:
                global start, end
                (start, end) = random.choice([match.span()])
        return line[: start + 1] + "[MASK]" + line[end - 1 :]
    
    def random_mask(self, line, masking=True):
        words = line.split()
        mask_idx = random.choice(range(len(words)))
        words[mask_idx] = "[MASK]"
        return " ".join(words)  

    def all_mask(self, lines, color_ratio, masking=True):
        masked = []
        for line in lines:
            coin = random.random()  
            if coin > color_ratio:
                masked.append(self.random_mask(line))
            else:
                masked.append(self.color_mask(line))
        
        return masked

    def encode_lines(self, lines, masked, masking):
        if masking == True:
            batch_encoding = self.tokenizer(
                masked, add_special_tokens=True, truncation=True, padding=True, 
                max_length=args.max_len
            )
            return batch_encoding["input_ids"]

        elif masking == False:
            batch_encoding = self.tokenizer(
                lines, add_special_tokens=True, truncation=True, padding=True, 
                max_length=args.max_len
            )
            return batch_encoding["input_ids"]


    def __len__(self):
        return len(self.lines)
    
    
    def __getitem(self, idx):
        return torch.tensor(self.ids[idx], dtype=torch.long)


train_dataset = MaskedLMDataset(args.train, args.color_ratio, tokenizer, masking=True)
label_dataset = MaskedLMDataset(args.train, args.color_ratio, tokenizer, masking=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




In [8]:
train_dataset.lines[:10]

['look at those big gray eyes and that beautiful red hair!',
 'he held out a bottle of red water.',
 "both breathed hard, and she noticed a red slash across darian's face.",
 'the sun turned the dunes orange red and then quickly sank, leaving them in pre-moon darkness.',
 'stout, about the average height, broad, with huge red hands; he did not know, as the saying is, how to enter a drawing room and still less how to leave one; that is, how to say something particularly agreeable before going away.',
 'alex turned red when carmen stepped forward.',
 'the only bait he could find was a bright red blossom from a flower; but he knew fishes are easy to fool if anything bright attracts their attention, so he decided to try the blossom.',
 "what she expected to see when she turned was sarah's white plymouth, but the car that stopped before the house was allen's red eagle talon.",
 "don't forget the mess of red hair and freckles.",
 'she has on a pretty red dress.']

In [9]:
train_dataset.masked[:10]

['look [MASK] those big gray eyes and that beautiful red hair!',
 'he held out a bottle of red [MASK]',
 "both breathed hard, and she noticed a red slash across darian's [MASK]",
 'the sun turned the dunes [MASK] red and then quickly sank, leaving them in pre-moon darkness.',
 'stout, about the average height, broad, [MASK] huge red hands; he did not know, as the saying is, how to enter a drawing room and still less how to leave one; that is, how to say something particularly agreeable before going away.',
 'alex turned [MASK] when carmen stepped forward.',
 'the only bait he could find was a bright [MASK] blossom from a flower; but he knew fishes are easy to fool if anything bright attracts their attention, so he decided to try the blossom.',
 "what she expected to see when she turned was [MASK] white plymouth, but the car that stopped before the house was allen's red eagle talon.",
 "don't forget the mess of [MASK] hair and freckles.",
 'she has on a pretty [MASK] dress.']

In [10]:
train_dataset.lines[-10:]

['He had green eyes and a pure cat soul and wicked swipe and was a supreme cuddler',
 'Oh and last but not least my tennis record ils 4 3 Not bad as long as I ve won as many matches as I ve lost I d consider myself in the green',
 'She told him taking her red dress off as she walked towards the bed',
 'the brownie boyz vip brown nigger sank brown neck ash brown dawg and i brown trash plus mal are going on a road trip on thursday to baltimore washington d',
 'It seems that despite instilling remarkable fear in me whenever a porn pop up appears and not preventing anyone from entering the building who shouldn t having those boys in blue two floors below has its uses',
 'Whether it was by the lake on the shore or on the shore by the park or under the roof of my home I cannot remember but clearly in my memory I will always bare the vision of three red fish lying beside three perfectly round stones',
 'I wandered in a tortured myself for while found some white trainers with a blue stripe',
 

In [11]:
train_dataset.masked[-10:]

['He had [MASK] eyes and a pure cat soul and wicked swipe and was a supreme cuddler',
 'Oh and [MASK]but not least my tennis record ils 4 3 Not bad as long as I ve won as many matches as I ve lost I d consider myself in the green',
 'She told [MASK] taking her red dress off as she walked towards the bed',
 'the brownie boyz [MASK] brown nigger sank brown neck ash brown dawg and i brown trash plus mal are going on a road trip on thursday to baltimore washington d',
 'It seems that despite instilling remarkable fear in me whenever a porn pop up appears and not preventing anyone from entering the building who shouldn t having those boys in [MASK] two floors below has its uses',
 'Whether it was by the lake on the shore or on the shore by the park or under the roof of my home I cannot remember but clearly in my memory I will always bare the vision of three red fish lying beside three perfectly [MASK] stones',
 'I wandered in a tortured myself for while found some [MASK] trainers with a blu

In [12]:
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size)

label_loader = DataLoader(
    label_dataset,
    batch_size=args.batch_size)