In [None]:
from transformers import BertTokenizer, DataCollatorForLanguageModeling
from transformers import BertForMaskedLM
from datasets import Dataset
import pandas as pd
import pickle 
from torch.utils.data import DataLoader
import torch
from transformers import AdamW

In [None]:
# Load the raw directions, discard every row with a missing field, and keep
# only the underlying array of cell values.
directions = pd.read_csv("directions.csv", index_col=0).dropna().values

In [None]:
# Auxiliary corpus: the "original cooking steps" column as a plain Python list.
corpus = list(
    pd.read_csv("data/corpus_word_embedding.csv", index_col=0)[
        "original cooking steps"
    ].values
)

In [None]:
# Optionally fold the auxiliary ``corpus`` into the training data. This
# replaces a commented-out ``directions.extend(corpus)``, which could not have
# worked: ``directions`` is a NumPy array at this point (``.values`` above) and
# arrays have no ``extend`` method. Off by default, matching prior behaviour.
INCLUDE_CORPUS = False

directions = pd.DataFrame(directions)
if INCLUDE_CORPUS:
    # Drop missing corpus entries so they do not become the literal text "nan".
    directions = pd.concat(
        [directions, pd.DataFrame(corpus).dropna()], ignore_index=True
    )

In [None]:
# Name the single text column; raises ValueError if there is more than one column.
directions = directions.set_axis(["Directions"], axis="columns")

In [None]:
def split_text(text):
    """Coerce ``text`` to ``str`` and return its whitespace-separated words.

    Runs of spaces, tabs and newlines all count as a single separator, and
    leading/trailing whitespace is ignored, so blank input yields ``[]``.
    """
    as_string = str(text)
    words = as_string.split(sep=None)
    return words

In [None]:
# Tokenise each direction on whitespace -> a Series (named "Directions") of word lists.
directions = directions.loc[:, "Directions"].apply(split_text)

In [None]:
# Back to a one-column DataFrame; the Series name becomes the column label.
directions = directions.to_frame()

In [None]:
# After ``split_text`` every cell holds a list (possibly empty), never NaN, so
# ``dropna`` on its own removes nothing. Also drop rows with no tokens
# (e.g. whitespace-only directions): they would become training texts with no
# words in them.
directions = directions.dropna()
directions = directions[directions["Directions"].map(len) > 0]

In [None]:
# Inspect the tokenised cell values (notebook display only).
directions.to_numpy()

In [None]:
# Re-join every word list with single spaces, yielding one training sentence per row.
texts = list(map(" ".join, directions.to_numpy().ravel()))

In [None]:
# Preview the reconstructed training sentences (last expression of the cell,
# so the notebook displays it).
texts

In [None]:
# Pretrained BERT with its masked-language-modelling head, plus the matching
# uncased tokenizer.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
# Use a GPU when one is available. The training loop sends each batch to
# ``model.device``, so it follows the model automatically; previously the
# model always stayed on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Tokenize the corpus. Texts longer than 512 tokens are truncated. Shorter
# ones are padded only up to the longest text in ``texts`` (never beyond
# 512), rather than always to 512. That produces the same rectangular tensors
# when some text hits the limit, and avoids wasting memory and compute on
# padding positions when every direction is short.
inputs = tokenizer(texts, return_tensors='pt', max_length=512, truncation=True, padding="longest")

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over tokenizer encodings.

    ``encodings`` is a mapping (such as the tokenizer's ``BatchEncoding``)
    from field name (``input_ids``, ``attention_mask``, ...) to a
    per-example sequence or tensor. Item ``idx`` is a dict holding each
    field's ``idx``-th entry as a fresh tensor.

    This derives from ``torch.utils.data.Dataset``. The ``Dataset`` imported
    from HuggingFace ``datasets`` is an Arrow-backed table class, not a base
    for hand-written datasets. Recent versions of it also define a batched
    ``__getitems__`` that ``DataLoader`` may call in preference to
    ``__getitem__``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    @staticmethod
    def _as_tensor(value):
        """Return ``value`` as a tensor that does not alias the stored encodings."""
        # torch.tensor() on an existing tensor warns and recommends
        # clone().detach(); list-valued encodings are converted normally.
        if isinstance(value, torch.Tensor):
            return value.clone().detach()
        return torch.tensor(value)

    def __getitem__(self, idx):
        return {key: self._as_tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        # Subscript (not attribute) access so plain dicts work as well as BatchEncoding.
        return len(self.encodings["input_ids"])


# Builds MLM batches: selects ~15% of tokens per batch for masking and
# produces the corresponding labels.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

NUM_EPOCHS = 3
# NOTE(review): 128 sequences of up to 512 tokens on bert-base needs a lot of
# accelerator memory; lower this if training runs out of memory.
BATCH_SIZE = 128
LEARNING_RATE = 5e-5

dataset = MyDataset(inputs)

# Shuffle so each epoch sees the examples in a fresh order rather than file order.
data_loader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator
)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. ``eps`` and
# ``weight_decay`` are set to transformers.AdamW's defaults (1e-6, 0.0)
# because torch's own defaults differ (1e-8, 0.01).
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0
)

model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    num_batches = 0
    for batch in data_loader:
        batch = {k: v.to(model.device) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        total_loss += loss.item()
        num_batches += 1
    # Report the epoch's mean loss instead of only the last batch's. This
    # also avoids a NameError on ``loss`` if the loader yields no batches.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch: {epoch}, Loss: {mean_loss}")