In [None]:
!pip install transformers

In [None]:
import torch
import torch.nn as nn
import transformers
import random

In [None]:
# BERT-backed model: pooled sequence embedding -> hidden layer -> vocabulary-sized logits
class TextGenerator(nn.Module):
    """Project an encoder's pooled output onto a vocabulary-sized logit vector.

    The wrapped encoder's ``pooler_output`` is passed through a linear layer,
    a ReLU, and a second linear layer whose output width is the encoder's
    ``config.vocab_size``.

    Args:
        bert_model: Encoder module. It must expose ``config.hidden_size`` and
            ``config.vocab_size``, and calling it must return an object with a
            ``pooler_output`` attribute.
        hidden_dim: Width of the intermediate layer between the pooled
            encoder output and the vocabulary projection.
    """

    def __init__(self, bert_model, hidden_dim):
        super().__init__()
        encoder_config = bert_model.config
        # Submodules are registered in the order bert, fc, relu, decoder.
        self.bert = bert_model
        self.fc = nn.Linear(encoder_config.hidden_size, hidden_dim)
        self.relu = nn.ReLU()
        self.decoder = nn.Linear(hidden_dim, encoder_config.vocab_size)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and map its pooled output to vocabulary logits.

        Args:
            input_ids: Token ids, forwarded positionally to the encoder.
            attention_mask: Forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of the final linear layer. Presumably shaped
            ``(batch, vocab_size)``, i.e. one logit vector per sequence rather
            than per token -- TODO confirm against the encoder's
            ``pooler_output`` shape.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        projected = self.relu(self.fc(encoded.pooler_output))
        return self.decoder(projected)

In [None]:
# Loss module comparing two logit tensors with mean squared error
class AnomalyLoss(nn.Module):
    """Mean-squared-error loss between predicted and target logits.

    NOTE(review): this loss was originally described as one that "encourages
    anomalies", but MSE is smallest when ``logits`` equals ``target_logits``,
    so minimizing it pulls the two together -- confirm the intended objective.
    """

    def forward(self, logits, target_logits):
        """Return ``mse_loss(logits, target_logits)`` using its default reduction.

        Args:
            logits: Predicted values.
            target_logits: Reference values to compare against.

        Returns:
            The loss tensor produced by ``torch.nn.functional.mse_loss``.
        """
        return nn.functional.mse_loss(logits, target_logits)

In [None]:
# Implement anomaly simulation (random word replacement)
def simulate_anomaly(input_text):
    """Return ``input_text`` with one word replaced by a random word.

    The text is split on whitespace, one word is swapped for the result of
    ``generate_random_word()``, and the words are re-joined with single
    spaces (so runs of whitespace in the input are collapsed).

    When the text has three or more words, the first and last words are never
    replaced. With only one or two words there is no interior word, so any
    word may be replaced. Text containing no words at all is returned
    unchanged.

    Args:
        input_text: The sentence to perturb.

    Returns:
        The perturbed sentence, or ``input_text`` itself when it contains no
        words.
    """
    words = input_text.split()

    # Nothing to replace: hand the text back untouched instead of crashing.
    if not words:
        return input_text

    if len(words) >= 3:
        # Interior words only; random.randint is inclusive on both ends.
        word_index_to_replace = random.randint(1, len(words) - 2)
    else:
        # random.randint(1, len(words) - 2) would be an empty range here and
        # raise ValueError, so fall back to picking any word.
        word_index_to_replace = random.randrange(len(words))

    words[word_index_to_replace] = generate_random_word()

    return ' '.join(words)

# Helper that produces the replacement "word" (swap in more sophisticated logic if needed)
def generate_random_word():
    """Return a random string of 1 to 10 lowercase ASCII letters."""
    alphabet = 'abcdefghijklmnopqrstuvwxyz'
    # Draw the length first, then one letter per position.
    word_length = random.randint(1, 10)
    letters = [random.choice(alphabet) for _ in range(word_length)]
    return ''.join(letters)

In [None]:
# Load the pretrained 'bert-base-uncased' BERT tokenizer; the training loop
# uses it to turn raw sentences into model input tensors.
tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
# Load the pretrained 'bert-base-uncased' BERT model; it is passed to
# TextGenerator as the encoder.
bert_model = transformers.BertModel.from_pretrained('bert-base-uncased')

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
# Build the generator around the pretrained BERT model. hidden_dim=256 is the
# width of the layer between BERT's pooled output and the vocabulary projection.
generator = TextGenerator(bert_model, hidden_dim=256)

In [None]:
# Adam over every parameter of the generator. Because the generator holds the
# BERT model as a submodule, the BERT weights are optimized as well.
# NOTE(review): lr=0.001 looks high for updating a pretrained transformer -- confirm.
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)
anomaly_loss = AnomalyLoss()

In [None]:
# Small placeholder corpus of "normal" sentences that the training loop
# iterates over (replace this with your own data).
normal_text_data = [
    "This is a normal sentence.",
    "Anomaly detection is important for data security.",
    "The quick brown fox jumps over the lazy dog."
]

# toy dataset
# create masking pattern
# regeneration

# implement the DRAEM paper
# directly use data masking pattern
#

In [None]:
# Training loop for the generator.
# Each sentence is processed on its own: the generator's output on the
# original sentence is the (gradient-free) target, and the generator's output
# on a randomly perturbed copy of the sentence is trained to match it.
# NOTE(review): target and prediction come from the same model and the loss is
# MSE, so this objective pulls the outputs for anomalous text toward the
# outputs for normal text -- confirm this is the intended behaviour.
num_epochs = 10
for epoch in range(num_epochs):
    for text in normal_text_data:
        # Convert text to input tensors (using BERT tokenizer)
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Compute the target logits on the unmodified text; no_grad keeps
        # this pass out of the autograd graph.
        with torch.no_grad():
            target_logits = generator(input_ids, attention_mask)

        # Add anomalies to the text (e.g., random word replacement)
        text_with_anomalies = simulate_anomaly(text)

        # Convert text with anomalies to input tensors
        inputs_with_anomalies = tokenizer(text_with_anomalies, return_tensors='pt', padding=True, truncation=True)
        input_ids_with_anomalies = inputs_with_anomalies['input_ids']
        attention_mask_with_anomalies = inputs_with_anomalies['attention_mask']

        # Forward pass on the perturbed text and loss against the targets
        optimizer.zero_grad()
        logits = generator(input_ids_with_anomalies, attention_mask_with_anomalies)
        loss = anomaly_loss(logits, target_logits)


        # Print the normal and anomalous data along with the loss
        # (the value printed is the loss before this step's weight update)
        print(f"Normal Text: {text}")
        print(f"Anomalous Text: {text_with_anomalies}")
        print(f"Loss: {loss.item()}")
        print("\n\n\n")
        # Backpropagation and optimization
        loss.backward()
        optimizer.step()

        # Display progress or log loss values

# After training, you can use the generator to produce pseudo anomalies


Normal Text: This is a normal sentence.
Anomalous Text: This ixxafgj a normal sentence.
Loss: 0.0004055955505464226




Normal Text: Anomaly detection is important for data security.
Anomalous Text: Anomaly detection is important for pbgaopu security.
Loss: 8.876712672645226e-05




Normal Text: The quick brown fox jumps over the lazy dog.
Anomalous Text: The quick brown fox jumps over xsx lazy dog.
Loss: 2.084578045469243e-05




Normal Text: This is a normal sentence.
Anomalous Text: This rqmshkk a normal sentence.
Loss: 0.001122341025620699




Normal Text: Anomaly detection is important for data security.
Anomalous Text: Anomaly detection is ysra for data security.
Loss: 0.003141367109492421




Normal Text: The quick brown fox jumps over the lazy dog.
Anomalous Text: The qdyqvnfo brown fox jumps over the lazy dog.
Loss: 7.670503805456974e-07




Normal Text: This is a normal sentence.
Anomalous Text: This is zfhsx normal sentence.
Loss: 1.5098141830094391e-06




Normal Text: Anom