In [None]:
from typing import NamedTuple, Optional

import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# One checkpoint name for both objects, so the tokenizer vocabulary always
# matches the encoder weights.
PRETRAINED_CHECKPOINT = 'bert-base-uncased'

# Pre-trained WordPiece tokenizer and BERT encoder.
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_CHECKPOINT)
model = BertModel.from_pretrained(PRETRAINED_CHECKPOINT)

# Freeze the encoder: no pre-trained weight receives gradients, so only
# layers added on top of it can be trained.
model.requires_grad_(False)

# Add a custom per-token classification layer on top of BERT
class TokenPredictionOutput(NamedTuple):
    """Result of ``CustomBERTModel.forward``.

    ``loss`` is ``None`` when no labels were supplied; ``logits`` has shape
    ``(batch, seq_len, num_labels)``.
    """

    loss: Optional[torch.Tensor]
    logits: torch.Tensor


class CustomBERTModel(nn.Module):
    """BERT encoder followed by dropout and a linear per-token classifier.

    Args:
        num_labels: Size of the output vocabulary (number of classes
            predicted at every token position).
        bert: Encoder to wrap. When omitted, the module-level ``model`` is
            used *as it is bound at construction time*; pass the encoder
            explicitly if that global may have been rebound.
    """

    def __init__(self, num_labels, bert=None):
        super().__init__()
        self.bert = bert if bert is not None else model
        self.dropout = nn.Dropout(0.1)
        # Read the width from the encoder instead of hard-coding 768 so other
        # BERT sizes work unchanged.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)
        # Default ignore_index is -100, matching the label convention used
        # for positions that must not contribute to the loss.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Score every token position against the label vocabulary.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: ``(batch, seq_len)`` mask, 1 for real tokens.
            labels: Optional ``(batch, seq_len)`` target ids; positions set
                to -100 are excluded from the loss.

        Returns:
            ``TokenPredictionOutput(loss, logits)``; ``loss`` is ``None``
            when ``labels`` is not given.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # outputs[0] is the per-token hidden state; the pooled vector
        # (outputs[1]) would yield only one prediction per sequence.
        sequence_output = self.dropout(outputs[0])
        logits = self.fc(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_fn(
                logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
            )
        return TokenPredictionOutput(loss=loss, logits=logits)

# Load text file and list of words to not predict
# Explicit encoding: the platform default is not guaranteed to be UTF-8.
with open('text_file.txt', 'r', encoding='utf-8') as f:
    text = f.read()

not_predict = ['word1', 'word2', 'word3']

# Tokenize text and convert to input tensors
tokens = tokenizer.tokenize(text)

# The encoder has a fixed number of position embeddings; reserve two slots
# for [CLS]/[SEP] and truncate instead of crashing on long files.
max_tokens = tokenizer.model_max_length - 2
if len(tokens) > max_tokens:
    print('Warning: text has', len(tokens), 'tokens; truncating to', max_tokens)
    tokens = tokens[:max_tokens]
tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]

# Keep the unmasked ids: they are the prediction targets for masked positions.
target_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens)).unsqueeze(0)

# Matching is done on tokenizer output, so an entry only matches when it
# appears as a single token exactly as the tokenizer emits it.
excluded = set(not_predict)
tokens = [tokenizer.mask_token if token in excluded else token for token in tokens]
input_ids = torch.tensor(tokenizer.convert_tokens_to_ids(tokens)).unsqueeze(0)
attention_mask = torch.ones_like(input_ids)

# Define model and optimizer
model = CustomBERTModel(num_labels=len(tokenizer))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# Fine-tune model
# Targets are the ORIGINAL token ids at masked positions; every other
# position is set to -100 so the loss ignores it.
labels = target_ids.clone()
labels[input_ids != tokenizer.mask_token_id] = -100

model.train()
if (labels != -100).any():
    for epoch in range(10):
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()

        print('Epoch:', epoch, 'Loss:', loss.item())
else:
    # With no supervised position the loss would be NaN and one optimizer
    # step would poison the weights, so skip training entirely.
    print('No token from not_predict occurs in the text; skipping fine-tuning.')

# Generate predictions for text with masked words
# The [MASK] ids must stay in place: they mark the positions to fill in.
# Overwriting them (e.g. with 0, the [PAD] id) would erase that signal.
masked_input_ids = input_ids.clone().detach()
mask_positions = masked_input_ids == tokenizer.mask_token_id

# Inference: disable dropout and skip gradient bookkeeping.
model.eval()
with torch.no_grad():
    outputs = model(input_ids=masked_input_ids, attention_mask=attention_mask)
predictions = torch.argmax(outputs.logits, dim=-1)

# Only masked positions are filled in; every other token is kept as given.
predictions = torch.where(mask_positions, predictions, masked_input_ids)

# Convert predictions back to tokens
predicted_tokens = tokenizer.convert_ids_to_tokens(predictions[0].tolist())
# Drop sequence-boundary markers if the input carried them.
boundary_tokens = {tokenizer.cls_token, tokenizer.sep_token}
predicted_tokens = [token for token in predicted_tokens if token not in boundary_tokens]
predicted_text = tokenizer.convert_tokens_to_string(predicted_tokens)
