In [None]:
from transformers import BertTokenizer, EncoderDecoderModel, BertGenerationEncoder, BertGenerationConfig, BertGenerationDecoder, BertGenerationTokenizer
import transformers
import torch
from torch import nn
from torch.utils.data import DataLoader
from datasets import load_dataset,concatenate_datasets

# Trained checkpoints (whole pickled modules) for the autoencoder and the
# attribute classifier.
autoencoder_checkpoint = "../checkpoints/latent_space_representation_edit_model_wang_controllable_2019/autoencoder_3500_0.008632369404399974.pth"
classifier_checkpoint = "../checkpoints/latent_space_representation_edit_model_wang_controllable_2019/classifier_3500_0.04627384876884106.pth"

# Tokenization / decoding settings. MAX_SEQUENCE_LENGTH is only a fallback:
# a later cell replaces it with the autoencoder's configured max_length.
MAX_SEQUENCE_LENGTH = 60
batch_size = 2
skip_special_tokens = True

# Prefer the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Attention-mask values.
MASKED, NOT_MASKED = 0, 1

# Dataset labels: legal (LeNER-Br) text vs. non-legal (Wikipedia) text.
LEGAL_TEXT_LABEL, NON_LEGAL_TEXT_LABEL = 1, 0




# Load models

In [None]:
class Classifier(nn.Module):
    """
    Small MLP that maps a latent vector to per-label probabilities.

    Based on the code from @wang_controllable_2019. Layout:
    Linear(latent_size -> 100) -> LeakyReLU(0.2) -> Linear(100 -> 50)
    -> LeakyReLU(0.2) -> Linear(50 -> output_size) -> Sigmoid.

    The sub-module attribute names (fc1, relu1, fc2, relu2, fc3, sigmoid)
    are part of the saved checkpoints' structure and are looked up by
    `forward`, so they must stay as they are.
    """

    def __init__(self, latent_size, output_size):
        super().__init__()
        negative_slope = 0.2
        # Layers are created in this exact order so parameter initialisation
        # consumes the RNG the same way as before.
        self.fc1 = nn.Linear(latent_size, 100)
        self.relu1 = nn.LeakyReLU(negative_slope)
        self.fc2 = nn.Linear(100, 50)
        self.relu2 = nn.LeakyReLU(negative_slope)
        self.fc3 = nn.Linear(50, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        """Return sigmoid scores of size `output_size` along the last dimension of `input` (size `latent_size`)."""
        hidden = self.relu1(self.fc1(input))
        hidden = self.relu2(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))  # batch_size * label_size


In [None]:
# Both checkpoints are whole pickled modules (they are used directly as models
# below), not state_dicts. Since PyTorch 2.6 `torch.load` defaults to
# `weights_only=True`, which refuses such files, so opt out explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints from
# a trusted source. Unpickling the classifier also requires the `Classifier`
# class defined in this notebook to be in scope.
classifier = torch.load(
    classifier_checkpoint, map_location=torch.device(device), weights_only=False
)
autoencoder = torch.load(
    autoencoder_checkpoint, map_location=torch.device(device), weights_only=False
)
tokenizer = BertTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")

In [None]:
# Replace the fallback from the configuration cell with the length stored in the
# autoencoder's config, so later cells pad/truncate/generate to that length.
autoencoder_config = autoencoder.config
MAX_SEQUENCE_LENGTH = autoencoder_config.max_length
MAX_SEQUENCE_LENGTH

# Prepare dataset

In [None]:
def prepare_dataset(tokenizer, max_sequence_length, num_proc=10):
    """
    Build a labelled, tokenized legal-vs-Wikipedia dataset.

    - Text from "pierreguillou/lener_br_finetuning_language_model" is labelled
      LEGAL_TEXT_LABEL; sentences from "jvanz/portuguese_wikipedia_sentences"
      are labelled NON_LEGAL_TEXT_LABEL (column "label").
    - The "train" split becomes Wikipedia train + legal train, and a new
      "evaluation" split becomes Wikipedia evaluation + legal validation. Any
      other split of the legal dataset (e.g. its original "validation") is kept
      unchanged and therefore contains legal text only.
    - Every split is tokenized without special tokens, padding or truncation,
      and gains a "tokens_count" column holding the number of token ids.

    Args:
        tokenizer: Hugging Face tokenizer applied to every split; must not be None.
        max_sequence_length: currently unused -- padding/truncation are disabled
            so that "tokens_count" reflects the full sentence length. Kept for
            interface compatibility.
        num_proc: number of worker processes for the `map` calls.

    Returns:
        The shuffled (seed 42) DatasetDict, formatted as torch tensors with the
        columns input_ids, token_type_ids, attention_mask, label, tokens_count.
    """
    assert tokenizer is not None, "prepare_dataset needs a tokenizer"

    legal_text = load_dataset(
        "pierreguillou/lener_br_finetuning_language_model", streaming=False
    )
    legal_text = legal_text.map(
        lambda x: {"label": [LEGAL_TEXT_LABEL] * len(x["text"])},
        num_proc=num_proc,
        batched=True,
    )

    wikipedia = load_dataset("jvanz/portuguese_wikipedia_sentences")
    wikipedia = wikipedia.map(
        lambda x: {"label": [NON_LEGAL_TEXT_LABEL] * len(x["text"])},
        num_proc=num_proc,
        batched=True,
    )

    # Mix both sources so the train and evaluation splits contain both labels.
    legal_text["train"] = concatenate_datasets([wikipedia["train"], legal_text["train"]])
    legal_text["evaluation"] = concatenate_datasets(
        [wikipedia["evaluation"], legal_text["validation"]]
    )

    # No padding/truncation: tokens_count must be the real sentence length.
    legal_text = legal_text.map(
        lambda x: tokenizer(x["text"], add_special_tokens=False),
        num_proc=num_proc,
        batched=True,
    )
    legal_text = legal_text.map(
        lambda x: {"tokens_count": [len(ids) for ids in x["input_ids"]]},
        num_proc=num_proc,
        batched=True,
    )

    legal_text.set_format(
        type="torch",
        # "label" is the column created above; set_format rejects column names
        # that do not exist in the dataset.
        columns=["input_ids", "token_type_ids", "attention_mask", "label", "tokens_count"],
    )
    legal_text = legal_text.shuffle(seed=42)
    print(legal_text)
    return legal_text


# Sentence-length window (in tokens, no special tokens) for the editing experiments.
# NOTE(review): the upper bound may exceed MAX_SEQUENCE_LENGTH (the generation
# length), in which case the decoder cannot reproduce the whole sentence -- confirm.
MIN_EDIT_TOKENS = 20
MAX_EDIT_TOKENS = 65

datasets = prepare_dataset(tokenizer, MAX_SEQUENCE_LENGTH)
# Legal sentences of the evaluation split within the length window, in one pass.
legal_dataset = datasets["evaluation"].filter(
    lambda x: x["label"] == LEGAL_TEXT_LABEL
    and MIN_EDIT_TOKENS <= x["tokens_count"] <= MAX_EDIT_TOKENS
)

# Controllable Unsupervised Text Attribute Transfer via Editing Entangled Latent Representation

```
@inproceedings{DBLP:journals/corr/abs-1905-12926,
  author    = {Ke Wang and Hang Hua and Xiaojun Wan},
  title     = {Controllable Unsupervised Text Attribute Transfer via Editing Entangled Latent Representation},
  booktitle = {NeurIPS},
  year      = {2019}
}
```

In [None]:
sentence = "Infere-se dos elementos coligidos que o agravante encontra-se na iminência de sofrer nova sanção, em decorrência do não pagamento da multa aplicada devido à retirada do tamponamento no hidrômetro, por um dos condôminos."
#sentence = "Esse é uma sentença para testar o modelo e ver se ele consegue recriar o texto. Ou seja, ver se a autodecodificação funciona."

# The checkpoint may have been saved in training mode; disable dropout so the
# reconstruction is deterministic.
autoencoder.eval()

# Tensors must live on the same device as the model loaded with map_location=device.
tokenizer_output = tokenizer(
    sentence,
    add_special_tokens=False,
    padding="max_length",
    truncation=True,
    max_length=MAX_SEQUENCE_LENGTH,
    return_tensors="pt",
).to(device)

print(tokenizer_output)
# The input is padded to MAX_SEQUENCE_LENGTH, so pass the mask to ignore padding.
outputs = autoencoder.generate(
    tokenizer_output.input_ids,
    attention_mask=tokenizer_output.attention_mask,
)
print(tokenizer.decode(outputs[0]))

In [None]:
autoencoder.eval()
classifier.eval()

# Sanity check: decode straight from the encoder's last hidden state. If the
# latent alone carries the sentence, this should reproduce the previous cell's output.
with torch.no_grad():  # inference only, no gradients needed
    autoencoder_output = autoencoder(
        input_ids=tokenizer_output.input_ids,
        attention_mask=tokenizer_output.attention_mask,
        labels=tokenizer_output.input_ids,
        output_attentions=True,
        output_hidden_states=True,
    )

data = autoencoder_output.encoder_last_hidden_state.detach()
encoder_outputs = transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=data,
    attentions=autoencoder_output.encoder_attentions,
)
# Pass the padding mask as well, so the decoder cross-attends only to real
# tokens, mirroring generate(input_ids, attention_mask=...).
output = autoencoder.generate(
    encoder_outputs=encoder_outputs,
    attention_mask=tokenizer_output.attention_mask,
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))

In [None]:
# Robustness check: scale every latent coordinate by independent U[0, 1) noise
# and see whether the decoder still reconstructs the sentence.
with torch.no_grad():  # inference only, no gradients needed
    autoencoder_output = autoencoder(
        input_ids=tokenizer_output.input_ids,
        attention_mask=tokenizer_output.attention_mask,
        labels=tokenizer_output.input_ids,
        output_attentions=True,
        output_hidden_states=True,
    )

hidden = autoencoder_output.encoder_last_hidden_state.detach()
# Draw the noise on the latent's own device/dtype (a CPU tensor would not
# multiply with a CUDA latent), from a seeded generator for reproducible re-runs.
noise_generator = torch.Generator(device=hidden.device).manual_seed(42)
noise = torch.rand(
    hidden.size(), generator=noise_generator, device=hidden.device, dtype=hidden.dtype
)
data = hidden * noise
encoder_outputs = transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(
    last_hidden_state=data,
    attentions=autoencoder_output.encoder_attentions,
)
# Mask padding positions in cross-attention, as generate(input_ids, attention_mask=...) does.
output = autoencoder.generate(
    encoder_outputs=encoder_outputs,
    attention_mask=tokenizer_output.attention_mask,
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))

In [None]:
def fast_gradient_iterative_modification(
    autoencoder,
    classifier,
    autoencoder_output,
    target,
    tokenizer,
    epsilons=(20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0),
    num_iterations=4,
    epsilon_decay=0.9,
    attention_mask=None,
):
    """
    Fast Gradient Iterative Modification (FGIM) of an encoder latent.

    Based on the code from @wang_controllable_2019. For every starting step
    size in `epsilons`, the latent is reset to the encoder output and then
    moved `num_iterations` times against the gradient of the classifier's BCE
    loss towards `target`. The step size is multiplied by `epsilon_decay`
    after each move. After every move the edited latent is decoded with
    `autoencoder.generate` and the first sequence is printed.

    Args:
        autoencoder: encoder-decoder model exposing `generate(encoder_outputs=...)`.
        classifier: module applied to the latent summed over dim 1; its last
            output column is compared with `target`.
        autoencoder_output: forward output providing `encoder_last_hidden_state`
            and `encoder_attentions`. It is not modified, and it does not need
            to carry a gradient graph.
        target: desired classifier probability for the last column -- one value
            per batch item, or a single value broadcast to the whole batch
            (e.g. `torch.zeros(1)` to push one sentence towards label 0).
        tokenizer: used to decode the generated ids.
        epsilons: initial step sizes to try (defaults match the original experiment).
        num_iterations: gradient steps per initial step size.
        epsilon_decay: factor applied to the step size after each step.
        attention_mask: optional encoder attention mask forwarded to `generate`
            so the decoder does not cross-attend to padding.

    Returns:
        A list of `(initial_epsilon, iteration, loss, text)` tuples, one per
        decoded step; `loss` is the classifier loss before that step.

    Raises:
        ValueError: if `target` cannot be matched to one value per batch item.
    """
    classifier_loss = nn.BCELoss(reduction="mean")

    # Work on a detached copy: gradients must stop at the latent instead of
    # flowing back into the autoencoder (whose parameter .grad would otherwise
    # keep accumulating), and this also works when the forward pass ran under
    # torch.no_grad().
    origin = autoencoder_output.encoder_last_hidden_state.detach()
    n_items = origin.size(0)

    # BCELoss needs a float target with exactly one value per batch item.
    target = torch.as_tensor(target, dtype=origin.dtype, device=origin.device).reshape(-1)
    if target.numel() == 1:
        target = target.expand(n_items)
    if target.numel() != n_items:
        raise ValueError(
            f"target must hold one value per batch item ({n_items}), "
            f"got shape {tuple(target.shape)}"
        )

    generate_kwargs = {} if attention_mask is None else {"attention_mask": attention_mask}
    results = []
    for initial_epsilon in epsilons:
        epsilon = initial_epsilon
        data = origin.clone()
        for iteration in range(num_iterations):
            print("epsilon:", epsilon)
            # Fresh leaf tensor each step, so the gradient is taken w.r.t. the
            # current latent only and not through the previous steps.
            data = data.detach().requires_grad_(True)
            output = classifier(torch.sum(data, dim=1))
            loss = classifier_loss(output[:, -1], target)
            print(f"Loss: {loss}")
            # autograd.grad only computes d(loss)/d(data); classifier .grad is untouched.
            (data_grad,) = torch.autograd.grad(loss, data)
            data = (data - epsilon * data_grad).detach()
            epsilon = epsilon * epsilon_decay

            encoder_outputs = transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=data,
                attentions=autoencoder_output.encoder_attentions,
            )
            generated = autoencoder.generate(encoder_outputs=encoder_outputs, **generate_kwargs)
            text = tokenizer.decode(generated[0], skip_special_tokens=True)
            print(text)
            results.append((initial_epsilon, iteration, loss.item(), text))
    return results



autoencoder.eval()
classifier.eval()

# Push every selected legal sentence towards the non-legal label and print the
# decoded text after each gradient step.
for sample in legal_dataset:
    # Batch of one, on the same device as the models.
    input_ids = sample["input_ids"].unsqueeze(0).to(device)
    attention_mask = sample["attention_mask"].unsqueeze(0).to(device)
    print(f"Original sentence: {tokenizer.decode(input_ids[0], skip_special_tokens=True)}")
    autoencoder_output = autoencoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
        output_attentions=True,
        output_hidden_states=True,
    )
    # One float target per sentence. Note: torch.zeros(NON_LEGAL_TEXT_LABEL) would
    # be an *empty* tensor, because the label value is 0.
    target = torch.full((1,), float(NON_LEGAL_TEXT_LABEL), device=device)
    fast_gradient_iterative_modification(autoencoder, classifier, autoencoder_output, target, tokenizer)
    print("-" * 50)
    print("\n")


```python
#generated_text = generate_text(tokenizer, autoencoder, data)
#print("| It {:2d} | classifier model pred {:5.4f} |".format(it, output[0].item()))                                                                                                                                                      
#print(generated_text)


def generate_text(tokenizer, decoder, latent_space, max_length=None, skip_special=None):
    """
    Greedily decode text from encoder hidden states (reference snippet; this
    lives in a markdown cell and is not executed by the notebook).

    Every sequence starts with `tokenizer.cls_token_id`; the arg-max token of
    the last position is appended until the sequence holds `max_length` tokens
    (start token included). There is no early stop on an end token.

    Args:
        tokenizer: provides `cls_token_id` and `batch_decode`.
        decoder: callable `decoder(input_ids, encoder_hidden_states=...)` whose
            result has `.logits`, assumed to be (batch, seq, vocab).
        latent_space: encoder hidden states; dim 0 is the batch.
        max_length: total decoded length; defaults to MAX_SEQUENCE_LENGTH.
        skip_special: forwarded as `skip_special_tokens`; defaults to the
            notebook-level `skip_special_tokens` flag.

    Returns:
        `tokenizer.batch_decode(...)` of the decoded ids, one entry per batch item.
    """
    if max_length is None:
        max_length = MAX_SEQUENCE_LENGTH
    if skip_special is None:
        skip_special = skip_special_tokens

    decoder_input_ids = torch.full(
        (latent_space.size(0), 1),
        tokenizer.cls_token_id,
        dtype=torch.long,
        device=latent_space.device,
    )
    with torch.no_grad():
        for _ in range(max_length - 1):
            output = decoder(decoder_input_ids, encoder_hidden_states=latent_space)
            # Only the prediction at the last position extends the sequence;
            # argmax of the logits equals argmax of their log-softmax.
            next_word = output.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_word], dim=1)
    return tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=skip_special)
```