<a href="https://colab.research.google.com/github/hjesse92/style_transfer_w266/blob/main/Text_Style_Disentanglement_BERT_GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install -q transformers sentencepiece rouge_score evaluate

In [None]:
# Show the attached GPU, driver / CUDA versions and current GPU memory usage.
!nvidia-smi

Wed Mar 22 08:41:39 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   42C    P0    54W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda" if has_gpu else "cpu")

if has_gpu:
    print('Number of GPU(s) available:', torch.cuda.device_count())
    print('GPU device name:', torch.cuda.get_device_name(0))
else:
    print('No GPU available')

Number of GPU(s) available: 1
GPU device name: NVIDIA A100-SXM4-40GB


In [None]:
from logging import warning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from transformers import GPT2Tokenizer, GPT2LMHeadModel, DistilBertTokenizer, DistilBertModel, AdamW

import pandas as pd
import numpy as np

import random
import warnings

# Silence library warnings to keep cell output readable.
# NOTE(review): this hides *all* warnings, including deprecations and misuse hints;
# consider narrowing with `category=` / `module=` filters.
warnings.filterwarnings('ignore')

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs are reproducible.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED);  # trailing ';' keeps the returned Generator object out of the cell output

<torch._C.Generator at 0x7f1cf51a0250>

In [None]:
from google.colab import drive

# Mount Google Drive so the TSV splits are reachable from the Colab VM.
DRIVE_MOUNT = '/content/drive'
drive.mount(DRIVE_MOUNT)

# Absolute paths under the mount point: they work whatever the current working directory is.
DATA_DIR = f'{DRIVE_MOUNT}/MyDrive/data'
train_file = f'{DATA_DIR}/original-train.tsv'
dev_file = f'{DATA_DIR}/original-dev.tsv'
test_file = f'{DATA_DIR}/original-test.tsv'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Release GPU memory that PyTorch's caching allocator holds but no tensor is using
# (handy after re-creating models in this notebook).
torch.cuda.empty_cache()

In [None]:
class TextStyleTransferDataset(Dataset):
    """Parallel (toxic source, non-toxic target) sentence pairs read from a two-column TSV.

    Each item is tokenized on the fly with both the encoder (DistilBERT) and the decoder (GPT-2)
    tokenizer. Every encoding is padded/truncated to ``max_len``.

    Args:
        data_path: path to a UTF-8 text file with one ``source<TAB>target`` pair per line.
            Blank lines are skipped.
            NOTE(review): a header row, if the file has one, is loaded as an ordinary pair; confirm.
        encoder_tokenizer: tokenizer for the BERT encoders.
        decoder_tokenizer: GPT-2 tokenizer. It is mutated in place: its pad token is set to its
            EOS token and padding switches to the left side.
        max_len: fixed sequence length for every encoding.

    Raises:
        ValueError: if a non-blank line does not contain exactly two tab-separated fields.
    """

    def __init__(self, data_path, encoder_tokenizer, decoder_tokenizer, max_len):
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.max_len = max_len

        # GPT-2 ships without a pad token; reuse EOS so fixed-length padding is possible.
        self.decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
        self.decoder_tokenizer.padding_side = 'left'

        self.data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue  # tolerate blank lines, e.g. a trailing newline at EOF
                fields = stripped.split('\t')
                if len(fields) != 2:
                    raise ValueError(f'{data_path}:{line_no}: expected 2 tab-separated fields, '
                                     f'got {len(fields)}')
                self.data.append((fields[0], fields[1]))

    def __len__(self):
        return len(self.data)

    def _encode(self, tokenizer, text):
        """Tokenize ``text`` to exactly ``max_len`` positions as a batch-of-one PyTorch tensor dict."""
        return tokenizer(text, max_length=self.max_len, padding='max_length', truncation=True,
                         return_tensors='pt')

    def __getitem__(self, idx):
        """Return the pair at ``idx`` encoded for both the BERT encoders and the GPT-2 decoder."""
        source, target = self.data[idx]
        return {'source_inputs': self._encode(self.encoder_tokenizer, source),
                'gpt_inputs': self._encode(self.decoder_tokenizer, source),
                'target_inputs': self._encode(self.encoder_tokenizer, target),
                'target_gpt_inputs': self._encode(self.decoder_tokenizer, target)}

class TextStyleTransferModel(nn.Module):
    """Toxic -> non-toxic style transfer: disentangled DistilBERT latents decoded through GPT-2.

    Pipeline (see ``forward``):
      1. Two DistilBERT encoders embed the toxic source and the non-toxic target. The last hidden
         layer is averaged over non-padding positions.
      2. Per-domain MLP autoencoders compress each pooled vector to a ``_LATENT_DIM`` latent.
         The first ``_SEMANTIC_DIM`` dims are treated as "semantic" and the rest as "style".
      3. The toxic semantic part is concatenated with the non-toxic style part, fused, decoded back
         to ``_HIDDEN_DIM`` and projected to GPT-2 vocabulary logits.
      4. The embedding of the highest-scoring GPT-2 token is repeated ``seq_len`` times and fed
         to GPT-2 as ``inputs_embeds``.

    Args:
        seq_len: number of positions fed to GPT-2. Defaults to 128, the dataset ``max_len`` used
            in this notebook.
        straight_through: if True (default), step 4 uses a straight-through estimator. The
            forward value is exactly the argmax token's embedding, but the token loss's gradient
            flows back through a softmax into mapping_model, style_transfer_decoder and fusion.
            If False, a plain ``argmax`` lookup is used, which blocks those gradients.
    """

    _HIDDEN_DIM = 768   # hidden size of the distilbert-base-uncased / gpt2 checkpoints used here
    _LATENT_DIM = 4     # autoencoder bottleneck width
    _SEMANTIC_DIM = 2   # latent[:, :2] -> semantic, latent[:, 2:] -> style

    def __init__(self, seq_len=128, straight_through=True):
        super(TextStyleTransferModel, self).__init__()
        self.seq_len = seq_len
        self.straight_through = straight_through

        self.bert_model_toxic = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.bert_model_nontoxic = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.gpt_model = GPT2LMHeadModel.from_pretrained("gpt2")

        # Autoencoders that deconstruct the toxic / non-toxic pooled embeddings into latents and
        # reconstruct them. Each is Linear, ReLU, Linear, ReLU, Linear, so state_dict keys are
        # toxic_encoder.0 / .2 / .4, and so on.
        encoder_sizes = (self._HIDDEN_DIM, 128, 64, self._LATENT_DIM)
        decoder_sizes = encoder_sizes[::-1]
        self.toxic_encoder = self._mlp(*encoder_sizes)
        self.toxic_decoder = self._mlp(*decoder_sizes)
        self.nontoxic_encoder = self._mlp(*encoder_sizes)
        self.nontoxic_decoder = self._mlp(*decoder_sizes)
        # Decodes the fused (toxic semantic + non-toxic style) latent back to hidden size.
        self.style_transfer_decoder = self._mlp(*decoder_sizes)

        self.fusion = nn.Linear(in_features=self._LATENT_DIM, out_features=self._LATENT_DIM, bias=True)
        # Hidden vector -> one score per GPT-2 vocabulary entry.
        self.mapping_model = nn.Linear(self._HIDDEN_DIM, self.gpt_model.config.vocab_size)

        # Non-toxic style prototype used by generate(). As a buffer it is included in state_dict()
        # and follows model.to(device). forward() refreshes it from every training batch.
        self.register_buffer('nontoxic_latent_style',
                             torch.zeros(1, self._LATENT_DIM - self._SEMANTIC_DIM))

    @staticmethod
    def _mlp(*sizes):
        """Build Linear -> ReLU -> ... -> Linear through ``sizes`` (no activation after the last layer)."""
        layers = []
        for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            if i:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(d_in, d_out))
        return nn.Sequential(*layers)

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, counting only positions where ``attention_mask`` is non-zero.

        With ``attention_mask=None`` this is a plain mean over dim 1.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        # clamp avoids division by zero for a row that is entirely padding
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def _split_latent(self, latent):
        """Split a latent batch into its (semantic, style) column blocks."""
        return latent[:, :self._SEMANTIC_DIM], latent[:, self._SEMANTIC_DIM:]

    def _fuse_and_map(self, semantic, style):
        """Fuse semantic and style latents, decode to hidden size, and score every GPT-2 token."""
        fusion_latent = self.fusion(torch.cat((semantic, style), dim=1))
        style_transfer_reconstructed = self.style_transfer_decoder(fusion_latent)
        return self.mapping_model(style_transfer_reconstructed)

    def _tokens_to_embeddings(self, logits, differentiable):
        """Return the GPT-2 input embedding of ``logits.argmax(-1)``.

        If ``differentiable`` is True, use a straight-through estimator. The returned value equals
        the argmax lookup, while gradients w.r.t. ``logits`` are those of ``softmax(logits) @ wte``.
        """
        wte = self.gpt_model.transformer.wte
        token_ids = logits.argmax(-1)
        if not differentiable:
            return wte(token_ids)
        probs = logits.softmax(-1)
        one_hot = F.one_hot(token_ids, num_classes=logits.size(-1)).to(probs.dtype)
        # (probs - probs.detach()) is exactly zero in value, so the forward pass sees one_hot,
        # but it carries the softmax gradient.
        straight_through = one_hot + (probs - probs.detach())
        return straight_through @ wte.weight

    def forward(self, toxic_input_ids, toxic_attention_mask, nontoxic_input_ids=None, nontoxic_attention_mask=None):
        """Run the full training pass on a batch of (toxic, non-toxic) pairs.

        Returns a 9-tuple:
        ``(gpt_logits, bert_pool_toxic, bert_pool_nontoxic, toxic_reconstructed,
        nontoxic_reconstructed, toxic_latent_semantic, nontoxic_latent_semantic,
        toxic_latent_style, nontoxic_latent_style)``.

        Raises:
            ValueError: if ``nontoxic_input_ids`` is None. The non-toxic side is required for training.
        """
        if nontoxic_input_ids is None:
            raise ValueError('forward() requires nontoxic_input_ids: each toxic sentence is paired '
                             'with its non-toxic target during training.')

        # Encoding
        toxic_bert = self.bert_model_toxic(input_ids=toxic_input_ids, attention_mask=toxic_attention_mask).last_hidden_state
        nontoxic_bert = self.bert_model_nontoxic(input_ids=nontoxic_input_ids, attention_mask=nontoxic_attention_mask).last_hidden_state

        # Inputs are padded to a fixed length, so average only over real tokens.
        bert_pool_toxic = self._masked_mean(toxic_bert, toxic_attention_mask)
        bert_pool_nontoxic = self._masked_mean(nontoxic_bert, nontoxic_attention_mask)

        toxic_latent_rep = self.toxic_encoder(bert_pool_toxic)
        toxic_latent_semantic, toxic_latent_style = self._split_latent(toxic_latent_rep)

        nontoxic_latent_rep = self.nontoxic_encoder(bert_pool_nontoxic)
        nontoxic_latent_semantic, nontoxic_latent_style = self._split_latent(nontoxic_latent_rep)

        if self.training:
            # Batch-mean prototype for generate(). detach() so the stored buffer does not keep the
            # autograd graph alive.
            self.nontoxic_latent_style = nontoxic_latent_style.detach().mean(dim=0, keepdim=True)

        toxic_reconstructed = self.toxic_decoder(toxic_latent_rep)
        nontoxic_reconstructed = self.nontoxic_decoder(nontoxic_latent_rep)

        # Key step: pair each toxic sentence's semantics with its target's non-toxic style.
        mapping = self._fuse_and_map(toxic_latent_semantic, nontoxic_latent_style)
        mapped_embeddings = self._tokens_to_embeddings(mapping, differentiable=self.straight_through)
        mapping_repeated = mapped_embeddings.unsqueeze(1).repeat(1, self.seq_len, 1)

        gpt_outputs = self.gpt_model(inputs_embeds=mapping_repeated, return_dict=True)
        gpt_logits = gpt_outputs.logits

        return (gpt_logits,
                bert_pool_toxic,
                bert_pool_nontoxic,
                toxic_reconstructed,
                nontoxic_reconstructed,
                toxic_latent_semantic,
                nontoxic_latent_semantic,
                toxic_latent_style,
                nontoxic_latent_style)

    def generate(self, toxic_input_ids, toxic_attention_mask, decoding_kwargs=None):
        '''
        Generate text from toxic inputs combined with the learned non-toxic style prototype.

        Args:
            toxic_input_ids: DistilBERT token ids of the toxic text.
            toxic_attention_mask: matching attention mask, or None to pool over every position.
            decoding_kwargs: keyword arguments for ``self.gpt_model.generate``. Defaults to
                nucleus / top-k sampling. The dict is copied, never mutated.

        Returns:
            The result of ``self.gpt_model.generate``.
        '''
        if not decoding_kwargs:
            decoding_kwargs = {
              "num_return_sequences": 1,
              "top_p": 0.9,  # For nucleus sampling
              "top_k": 50,  # For top-k sampling
              "temperature": 0.8,  # For temperature sampling
              "do_sample": True,  # Set to True for sampling-based methods
              "max_length": 128,  # The maximum length of the generated text
              "eos_token_id": self.gpt_model.config.eos_token_id
          }
        gen_kwargs = dict(decoding_kwargs)
        # GPT-2 has no pad token; name EOS explicitly instead of letting generate() warn and guess.
        gen_kwargs.setdefault('pad_token_id', self.gpt_model.config.eos_token_id)

        # Disable dropout for inference, then restore the caller's train/eval mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                toxic_bert = self.bert_model_toxic(input_ids=toxic_input_ids, attention_mask=toxic_attention_mask).last_hidden_state
                bert_pool_toxic = self._masked_mean(toxic_bert, toxic_attention_mask)
                toxic_latent_semantic, _ = self._split_latent(self.toxic_encoder(bert_pool_toxic))

                # Broadcast the (1, style_dim) prototype to this batch, whatever its size.
                style = self.nontoxic_latent_style.expand(toxic_latent_semantic.size(0), -1)
                mapping = self._fuse_and_map(toxic_latent_semantic, style)

                mapped_embeddings = self._tokens_to_embeddings(mapping, differentiable=False)
                mapping_repeated = mapped_embeddings.unsqueeze(1).repeat(1, self.seq_len, 1)

                gen_kwargs.setdefault('attention_mask',
                                      torch.ones(mapping_repeated.shape[:2], dtype=torch.long,
                                                 device=mapping_repeated.device))
                gpt_outputs = self.gpt_model.generate(inputs_embeds=mapping_repeated, **gen_kwargs)
        finally:
            self.train(was_training)

        return gpt_outputs

class StyleTransferLoss(nn.Module):
    """Weighted sum of autoencoder reconstruction, semantic alignment and GPT-2 token losses.

    total = reconstruction_weight * (MSE(toxic_reconstructed, bert_pool_toxic)
                                     + MSE(nontoxic_reconstructed, bert_pool_nontoxic))
            + MSE(toxic_latent_semantic, nontoxic_latent_semantic)
            + cross_entropy(output, target, ignore_index=pad_token_id)

    Args:
        pad_token_id: target id excluded from the token cross-entropy. If None, it is read from
            the notebook-global ``gpt_tokenizer`` at call time; if that is also None, -100
            (PyTorch's default ignore index) is used.
        reconstruction_weight: multiplier applied to each reconstruction MSE term (default 5.0).
    """

    def __init__(self, pad_token_id=None, reconstruction_weight=5.0):
        super(StyleTransferLoss, self).__init__()
        self.pad_token_id = pad_token_id
        self.reconstruction_weight = reconstruction_weight

    def forward(self,
                output,
                target,
                bert_pool_toxic,
                bert_pool_nontoxic,
                toxic_reconstructed,
                nontoxic_reconstructed,
                toxic_latent_semantic,
                nontoxic_latent_semantic,
                toxic_latent_style,
                nontoxic_latent_style
                ):
        """Return the scalar total loss.

        ``output`` holds GPT-2 logits whose last dim is the vocabulary; ``target`` holds GPT-2 ids
        with one id per logit row after flattening. The two ``*_latent_style`` arguments are
        accepted for interface compatibility and are currently unused.
        """
        pad_id = self.pad_token_id
        if pad_id is None:
            pad_id = gpt_tokenizer.pad_token_id  # legacy fallback: notebook-global tokenizer
        ignore_index = -100 if pad_id is None else pad_id

        reconstruction_loss = (F.mse_loss(toxic_reconstructed, bert_pool_toxic)
                               + F.mse_loss(nontoxic_reconstructed, bert_pool_nontoxic))

        # Pull the semantic halves of paired toxic / non-toxic sentences together.
        semantic_loss = F.mse_loss(toxic_latent_semantic, nontoxic_latent_semantic)

        token_loss = F.cross_entropy(output.reshape(-1, output.size(-1)),
                                     target.reshape(-1),
                                     ignore_index=ignore_index)

        return self.reconstruction_weight * reconstruction_loss + semantic_loss + token_loss

In [None]:
 # End-to-end training of TextStyleTransferModel on the parallel toxic -> non-toxic TSV data.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training hyper-parameters
batch_size = 8
num_epochs = 20
# NOTE(review): this learning rate also applies to the pretrained DistilBERT / GPT-2 weights; it is
# well above common fine-tuning rates. Consider separate parameter groups.
learning_rate = 1e-3
max_len = 128

bert_checkpoint = 'distilbert-base-uncased'
gpt_checkpoint = 'gpt2'

bert_tokenizer = DistilBertTokenizer.from_pretrained(bert_checkpoint)
gpt_tokenizer = GPT2Tokenizer.from_pretrained(gpt_checkpoint)

# Data
train_dataset = TextStyleTransferDataset(data_path=train_file,
                                         encoder_tokenizer=bert_tokenizer,
                                         decoder_tokenizer=gpt_tokenizer,
                                         max_len=max_len)
eval_dataset = TextStyleTransferDataset(data_path=dev_file,
                                        encoder_tokenizer=bert_tokenizer,
                                        decoder_tokenizer=gpt_tokenizer,
                                        max_len=max_len)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation data needs no shuffling; a fixed order keeps eval results comparable across runs.
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

# Model
model = TextStyleTransferModel()
model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 and weight_decay=0.0 match transformers.AdamW's defaults.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
criterion = StyleTransferLoss()
losses = []  # mean training loss per epoch


def _ids_and_mask(encoding):
    """Drop the per-example singleton dim left by the dataset's return_tensors='pt' and move to `device`."""
    return (encoding['input_ids'].squeeze(1).to(device),
            encoding['attention_mask'].squeeze(1).to(device))


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    prog_bar = tqdm(train_loader, total=len(train_loader))
    prog_bar.set_description(f"Epoch {epoch+1}/{num_epochs}")

    for i, data in enumerate(prog_bar):
        input_ids, attention_mask = _ids_and_mask(data['source_inputs'])
        target_ids, target_attention_mask = _ids_and_mask(data['target_inputs'])
        target_gpt_ids = data['target_gpt_inputs']['input_ids'].squeeze(1).to(device)

        optimizer.zero_grad()

        (gpt_output, bert_pool_toxic, bert_pool_nontoxic,
         toxic_reconstructed, nontoxic_reconstructed,
         toxic_latent_semantic, nontoxic_latent_semantic,
         toxic_latent_style, nontoxic_latent_style) = model(input_ids, attention_mask,
                                                              target_ids, target_attention_mask)

        loss = criterion(gpt_output,
                         target_gpt_ids,
                         bert_pool_toxic,
                         bert_pool_nontoxic,
                         toxic_reconstructed,
                         nontoxic_reconstructed,
                         toxic_latent_semantic,
                         nontoxic_latent_semantic,
                         toxic_latent_style,
                         nontoxic_latent_style)

        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        prog_bar.set_postfix(loss=running_loss / (i + 1))

    # Record the epoch *mean* (the figure shown in the progress bar) rather than the raw sum.
    losses.append(running_loss / max(len(train_loader), 1))

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- T

In [None]:
# Probe sentence for qualitative generation checks.
# model.generate() feeds these ids into the DistilBERT encoder, so they must come from the DistilBERT
# tokenizer; GPT-2 ids index a different vocabulary. Padding and truncation match
# TextStyleTransferDataset's training inputs.
test_text = "bitch ass get out of here you don't belong"
test_inputs = bert_tokenizer(test_text, max_length=128, padding='max_length', truncation=True,
                             return_tensors='pt')
test_input_id = test_inputs['input_ids'].to(device)
test_input_mask = test_inputs['attention_mask'].to(device)

In [None]:
# Sampling configuration for GPT-2 generation.
decoding_kwargs = dict(
    num_return_sequences=1,
    top_p=0.9,          # nucleus sampling
    top_k=50,           # top-k sampling
    temperature=0.8,    # softens / sharpens the next-token distribution
    do_sample=True,     # sample instead of greedy / beam decoding
    max_length=128,     # upper bound on generated sequence length
    eos_token_id=model.gpt_model.config.eos_token_id,
)

In [None]:
# Run the style-transfer generator on the probe sentence with the sampling settings in `decoding_kwargs`.
model.generate(test_input_id, test_input_mask, decoding_kwargs=decoding_kwargs)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[50256,    13,    13,    13,    13,    30,    13,    13,    13,    13,
           407,  2089,   329,   257,   318,   318,   257,   318,   340,   470,
           318,   470,   257,   257,   345,   340,    13,    13,    13,    13,
            13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
            13,    13,    13,   986,    13,    30,    13,    13,    13,    13,
            13,    13,    13,    30,    13,    30,    13,    13,    13,    13,
             0,    13,    13,    13,    13,    13,    13,    13,     0,    13,
            13,    13,    30,    13,  1639,   284,   389,   287,   345,    11,
            13,    40,    13,    13,  5812,   257,   345,   651,   389,   262,
           319,   345,    13,   407,   389,   821,   257,   470,    11,   284,
           447,   588,   284,  1639,    13,   338,   284,    11,    11,   470,
           470,   345,   318,  5812,   284,   284,   470,   257,   257,   319,
           470,   326,    13,   407,  1639,  1212,  

In [None]:
# Generate from the probe sentence and decode the GPT-2 ids back to text.
# Pass the real attention mask so padding positions are masked inside the DistilBERT encoder.
gpt_tokenizer.batch_decode(model.generate(test_input_id, test_input_mask, decoding_kwargs=decoding_kwargs),
                           skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[".. it to you think up it to just't up't you a to of? are you. you not for for....?....?................. it you is like a a a the.?. that you is the of's so of. people like I and're to,, so is with it're. people be get theyYou for't the be people,?? get't and are a a a to are, don to,'s this are is YouNo to you, you are of and you"]

In [None]:
# Size of the GPT-2 logits from the last training batch: (batch, positions, vocab).
gpt_output.size()

torch.Size([1, 128, 50257])

In [None]:
# FIXME(review): this cell fails on a fresh kernel. `test` is not defined anywhere in the visible
# notebook, `gpt_model` is only created by a later cell, and the saved output is a RuntimeError.
# GPT-2's generate() presumably needs inputs_embeds shaped (batch, seq, n_embd); confirm what
# `test` was meant to be, then fix or delete this cell.
gpt_model.generate(inputs_embeds=test)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


RuntimeError: ignored

In [None]:
# Stand-alone GPT-2 LM head model (separate from model.gpt_model) for the prototyping cells below.
# NOTE(review): the next cell loads "gpt2" into the same name again, so this cell is redundant.
gpt_model = GPT2LMHeadModel.from_pretrained("gpt2")

In [None]:
# --- Prototype (CPU): the model's layers as stand-alone modules, for interactive shape debugging ---
semantic_latent_dim = 8
style_latent_dim = 4
fusion_dim = 32

bert_model_toxic = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model_nontoxic = DistilBertModel.from_pretrained('distilbert-base-uncased')
gpt_model = GPT2LMHeadModel.from_pretrained("gpt2")

# Linear heads: pooled encoder vector -> semantic / style latents (nn.Linear has bias by default).
toxic_semantic = nn.Linear(bert_model_toxic.config.hidden_size, semantic_latent_dim)
toxic_style = nn.Linear(bert_model_toxic.config.hidden_size, style_latent_dim)

nontoxic_semantic = nn.Linear(bert_model_nontoxic.config.hidden_size, semantic_latent_dim)
nontoxic_style = nn.Linear(bert_model_nontoxic.config.hidden_size, style_latent_dim)

# Fuses the concatenated (semantic, style) latents.
fusion = nn.Linear(semantic_latent_dim + style_latent_dim, fusion_dim)

hidden1 = nn.Linear(fusion_dim, 5 * fusion_dim)

# One score per DistilBERT vocabulary entry.
final_layer = nn.Linear(hidden1.out_features, bert_model_toxic.config.vocab_size)

# NOTE(review): this maps DistilBERT-vocab-sized scores to GPT-2-vocab-sized scores, i.e. a
# vocab_size x vocab_size weight matrix. GPT-2's vocab is 50257; DistilBERT's is presumably ~30k
# (check bert_model_toxic.config.vocab_size). That is on the order of a billion float32 weights
# (several GB). Projecting to gpt_model.config.n_embd instead is far smaller.
mapping_model = nn.Linear(bert_model_toxic.config.vocab_size, gpt_model.config.vocab_size)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_transform.weight']
- T

In [None]:
# Prototype forward pass through the stand-alone layers, on one debug batch.
# Encode both sides of the batch.
toxic_bert = bert_model_toxic(input_ids, attention_mask=attention_mask).last_hidden_state
nontoxic_bert = bert_model_nontoxic(target_ids, attention_mask=target_attention_mask).last_hidden_state

# Mean of the last hidden layer over the sequence axis.
bert_pool_toxic = toxic_bert.mean(dim=1)
bert_pool_nontoxic = nontoxic_bert.mean(dim=1)

# Project each side onto its *own* semantic / style heads: the non-toxic pooled vector goes
# through nontoxic_semantic / nontoxic_style, not the toxic heads.
toxic_semantic_latent = toxic_semantic(bert_pool_toxic)
toxic_style_latent = toxic_style(bert_pool_toxic)

nontoxic_semantic_latent = nontoxic_semantic(bert_pool_nontoxic)
nontoxic_style_latent = nontoxic_style(bert_pool_nontoxic)

# Key step: fuse the toxic sentence's semantics with the non-toxic style.
fusion_latent = torch.cat((toxic_semantic_latent, nontoxic_style_latent), dim=1)
fusion_latent = fusion(fusion_latent)

# Dense layers between the fused latent and the vocabulary projection.
# (A freshly constructed nn.Dropout is in training mode, so dropout is always applied here.)
h1 = hidden1(fusion_latent)
h1 = nn.ReLU()(h1)
h1 = nn.Dropout(0.2)(h1)

mapping = final_layer(h1)

# Map DistilBERT-vocab scores onto GPT-2-vocab scores.
mapping = nn.ReLU()(mapping)
mapping = mapping_model(mapping)

# Embed the single highest-scoring GPT-2 token and repeat it across 128 positions.
mapped_embeddings = gpt_model.transformer.wte(mapping.argmax(-1))
mapping_repeated = mapped_embeddings.unsqueeze(1).repeat(1, 128, 1)

gpt_outputs = gpt_model(inputs_embeds=mapping_repeated, return_dict=True)
gpt_logits = gpt_outputs.logits



In [None]:
# GPT-2's token-embedding table (the same module as gpt_model.transformer.wte).
gpt_model.get_input_embeddings()

Embedding(50257, 768)

In [None]:
# Size of the embeddings looked up for the top-scoring GPT-2 token of each row of `mapping`.
gpt_model.transformer.wte(torch.argmax(mapping, dim=-1)).size()

torch.Size([4, 768])

In [None]:
# Rebuild tokenizers, datasets and loaders, then grab a single training batch for step-by-step
# debugging of the prototype layers.
# NOTE: rebinds the global `device` to the CPU for the prototype cells that follow.
bert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
device = 'cpu'
train_dataset = TextStyleTransferDataset(data_path=train_file,
                                         encoder_tokenizer=bert_tokenizer,
                                         decoder_tokenizer=gpt2_tokenizer,
                                         max_len=128)
eval_dataset = TextStyleTransferDataset(data_path=dev_file,
                                        encoder_tokenizer=bert_tokenizer,
                                        decoder_tokenizer=gpt2_tokenizer,
                                        max_len=128)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# Evaluation data needs no shuffling; a fixed order keeps eval results comparable.
eval_loader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

# Take exactly one (shuffled) batch.
data = next(iter(train_loader))
source_inputs = data['source_inputs']
target_inputs = data['target_inputs']
gpt_inputs = data['gpt_inputs']
target_gpt_inputs = data['target_gpt_inputs']

# squeeze(1) removes the per-example leading dim left by the dataset's return_tensors='pt'
# (presumably (batch, 1, max_len) -> (batch, max_len)).
input_ids = source_inputs['input_ids'].squeeze(1).to(device)
attention_mask = source_inputs['attention_mask'].squeeze(1).to(device)
gpt_ids = gpt_inputs['input_ids'].squeeze(1).to(device)
gpt_masks = gpt_inputs['attention_mask'].squeeze(1).to(device)

target_ids = target_inputs['input_ids'].squeeze(1).to(device)
target_attention_mask = target_inputs['attention_mask'].squeeze(1).to(device)
target_gpt_ids = target_gpt_inputs['input_ids'].squeeze(1).to(device)
target_gpt_masks = target_gpt_inputs['attention_mask'].squeeze(1).to(device)

In [None]:
# --- Prototype v2: the stand-alone layers placed on `device`; the final projection maps to GPT-2's
# embedding width (n_embd) instead of its vocabulary size ---
semantic_latent_dim = 8
style_latent_dim = 4
fusion_dim = 32

bert_model_toxic = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
bert_model_nontoxic = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
gpt_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# Linear heads: pooled encoder vector -> semantic / style latents (nn.Linear has bias by default).
toxic_semantic = nn.Linear(bert_model_toxic.config.hidden_size, semantic_latent_dim).to(device)
toxic_style = nn.Linear(bert_model_toxic.config.hidden_size, style_latent_dim).to(device)

nontoxic_semantic = nn.Linear(bert_model_nontoxic.config.hidden_size, semantic_latent_dim).to(device)
nontoxic_style = nn.Linear(bert_model_nontoxic.config.hidden_size, style_latent_dim).to(device)

# Fuses the concatenated (semantic, style) latents.
fusion = nn.Linear(semantic_latent_dim + style_latent_dim, fusion_dim).to(device)

hidden1 = nn.Linear(fusion_dim, 5 * fusion_dim).to(device)

# One score per DistilBERT vocabulary entry.
final_layer = nn.Linear(hidden1.out_features, bert_model_toxic.config.vocab_size).to(device)

# Vocabulary-sized scores -> a GPT-2-embedding-sized vector usable as inputs_embeds.
mapping_model = nn.Linear(bert_model_toxic.config.vocab_size, gpt_model.config.n_embd).to(device)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight']
- T

In [None]:
# Sampling configuration for the stand-alone GPT-2 model.
decoding_kwargs = dict(
    num_return_sequences=1,
    top_p=0.9,          # nucleus sampling
    top_k=50,           # top-k sampling
    temperature=0.8,    # softens / sharpens the next-token distribution
    do_sample=True,     # sample instead of greedy / beam decoding
    max_length=128,     # upper bound on generated sequence length
    eos_token_id=gpt_model.config.eos_token_id,
)

In [None]:
def masked_mean_pool(hidden, mask):
    """Average `hidden` over dim 1, counting only positions where `mask` is non-zero.

    The debug batch is padded to max_len, so a plain mean would be dominated by [PAD] states.
    """
    weights = mask.unsqueeze(-1).to(hidden.dtype)
    # clamp guards against a row that is entirely padding
    return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)


# Encode both sides of the debug batch.
toxic_bert = bert_model_toxic(input_ids, attention_mask=attention_mask).last_hidden_state
nontoxic_bert = bert_model_nontoxic(target_ids, attention_mask=target_attention_mask).last_hidden_state

# Pool the last hidden layer over real (non-padding) tokens only.
bert_pool_toxic = masked_mean_pool(toxic_bert, attention_mask)
bert_pool_nontoxic = masked_mean_pool(nontoxic_bert, target_attention_mask)

# Map each side onto its own semantic / style latents.
toxic_semantic_latent = toxic_semantic(bert_pool_toxic)
toxic_style_latent = toxic_style(bert_pool_toxic)

nontoxic_semantic_latent = nontoxic_semantic(bert_pool_nontoxic)
nontoxic_style_latent = nontoxic_style(bert_pool_nontoxic)

# Key step: fuse the toxic sentence's semantics with the non-toxic style.
fusion_latent = torch.cat((toxic_semantic_latent, nontoxic_style_latent), dim=1)
fusion_latent = fusion(fusion_latent)

# Dense layers between the fused latent and the vocabulary projection.
# (A freshly constructed nn.Dropout is in training mode, so dropout is always applied here.)
h1 = hidden1(fusion_latent)
h1 = nn.ReLU()(h1)
h1 = nn.Dropout(0.2)(h1)

mapping = final_layer(h1)

# Project DistilBERT-vocab-sized scores to GPT-2's embedding width (n_embd).
f = mapping_model(mapping)

In [None]:
# Feed the continuous n_embd-sized vector `f` straight into GPT-2 as input embeddings, with the
# same vector copied to each of 128 positions.
mapping_repeated = f[:, None, :].repeat(1, 128, 1)

gpt_outputs = gpt_model(inputs_embeds=mapping_repeated, return_dict=True)
gpt_logits = gpt_outputs['logits']


In [None]:
# Size of the prototype's GPT-2 logits: (batch, positions, vocab).
gpt_logits.size()

torch.Size([4, 128, 50257])

In [None]:
# FIXME(review): `output_logits` is only created by the cross-entropy cell further down, so this
# fails on a fresh Restart & Run All. Move it below that cell or delete it.
output_logits.shape

torch.Size([512, 50257])

In [None]:
# FIXME(review): `target` is only created by the cross-entropy cell further down, so this fails on
# a fresh Restart & Run All. Move it below that cell or delete it.
target.shape

torch.Size([512])

In [None]:
# Token-level cross-entropy between the prototype's GPT-2 logits and the GPT-2 target ids.
# Flatten (batch, seq, vocab) -> (batch*seq, vocab) and (batch, seq) -> (batch*seq,).
output_logits = gpt_logits.view(-1, gpt_logits.size(-1))
# Put the targets on the same device as the logits (the debug batch and the model may differ).
target = target_gpt_ids.view(-1).to(output_logits.device)
# Padding positions (the dataset sets GPT-2's pad token to EOS) must not count toward the loss.
F.cross_entropy(output_logits, target, ignore_index=gpt2_tokenizer.pad_token_id)

RuntimeError: ignored

tensor(0.0346, device='cuda:0', grad_fn=<MseLossBackward0>)

In [None]:
# Functional API alias (also imported in the top imports cell)
from torch.nn import functional as F


In [None]:
# Attempt to decode from the mapped vector with GPT-2 on the CPU.
# NOTE(review): Module.to() moves gpt_model in place, so later cells that call
# gpt_model with `device` tensors may hit a device mismatch.
# NOTE(review): GPT-2 expects inputs_embeds shaped (batch, seq_len, n_embd);
# `f[0][0]` drops both leading dims, consistent with the IndexError recorded
# below. TODO confirm the intended slice (e.g. `f[:1]`) and that f's last dim
# equals gpt_model.config.n_embd.
# NOTE(review): num_beams=5 with do_sample=True selects beam-sample decoding.
gpt_model.to('cpu')
f = f.to('cpu')
gpt_model.generate(inputs_embeds = f[0][0],
                   max_length = 128,
                   num_beams=5,
                   top_p = 0.93,
                   do_sample=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


IndexError: ignored

In [None]:
# Most likely vocab id per position.
# NOTE(review): requires `f` to be 3-D (batch, seq_len, vocab), i.e. after an
# `.unsqueeze(1).repeat(...)` step; a 2-D `f` has no dim=2.
torch.argmax(f, dim=2)

tensor([[12647, 12647, 12647,  ..., 12647, 12647, 12647],
        [19962, 19962, 19962,  ..., 19962, 19962, 19962],
        [30890, 30890, 30890,  ..., 30890, 30890, 30890],
        ...,
        [ 4384,  4384,  4384,  ...,  4384,  4384,  4384],
        [19962, 19962, 19962,  ..., 19962, 19962, 19962],
        [ 4384,  4384,  4384,  ...,  4384,  4384,  4384]])

In [None]:
# Latent sizes for the semantic / style split and the fusion layer
semantic_latent_dim = 128
style_latent_dim = 8
fusion_dim = 256

# Toxic branch: a DistilBERT encoder with linear semantic and style heads
bert_model_toxic = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
toxic_semantic = nn.Linear(in_features=bert_model_toxic.config.hidden_size,
                           out_features=semantic_latent_dim).to(device)
toxic_style = nn.Linear(in_features=bert_model_toxic.config.hidden_size,
                        out_features=style_latent_dim).to(device)

# Non-toxic branch: an independently loaded DistilBERT with its own heads
bert_model_nontoxic = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
nontoxic_semantic = nn.Linear(in_features=bert_model_nontoxic.config.hidden_size,
                              out_features=semantic_latent_dim).to(device)
nontoxic_style = nn.Linear(in_features=bert_model_nontoxic.config.hidden_size,
                           out_features=style_latent_dim).to(device)



Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight']
- T

In [None]:
# Fuse the concatenated [semantic | style] latent into a fusion_dim vector
fusion = nn.Linear(semantic_latent_dim + style_latent_dim, fusion_dim).to(device)

# Dense stack widths: fusion_dim -> 1.5x -> 2x
hidden1 = nn.Linear(in_features=fusion_dim, out_features=int(1.5 * fusion_dim)).to(device)
# NOTE(review): hidden2's input width is hidden1.in_features (== fusion_dim),
# not hidden1.out_features, so it cannot be applied directly to hidden1's
# output. TODO confirm the intended wiring.
hidden2 = nn.Linear(in_features=hidden1.in_features, out_features=int(2 * fusion_dim)).to(device)

# Project up to a vector the size of the BERT vocabulary
final_layer = nn.Linear(hidden2.out_features, bert_model_toxic.config.vocab_size).to(device)

In [None]:
# In this cell GPT-2 is loaded only to read its vocabulary size
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")

# MappingModel: BERT-vocab-sized input -> GPT-2-vocab-sized output
input_dim, output_dim = bert_model_toxic.config.vocab_size, gpt2_model.config.vocab_size
mapping_model = MappingModel(input_dim, output_dim).to(device)

In [None]:
# Encode the toxic source; masked mean so padding does not dilute the pool
toxic_bert = bert_model_toxic(input_ids, attention_mask=attention_mask).last_hidden_state
toxic_mask = attention_mask.unsqueeze(-1).to(toxic_bert.dtype)
toxic_pool = (toxic_bert * toxic_mask).sum(dim=1) / toxic_mask.sum(dim=1).clamp(min=1.0)
toxic_latent_semantic = toxic_semantic(toxic_pool)
toxic_latent_style = toxic_style(toxic_pool)

# Encode the non-toxic target with its own semantic head (nontoxic_semantic),
# in line with the nontoxic_style head used for its style latent.
nontoxic_bert = bert_model_nontoxic(target_ids, attention_mask=target_attention_mask).last_hidden_state
nontoxic_mask = target_attention_mask.unsqueeze(-1).to(nontoxic_bert.dtype)
nontoxic_pool = (nontoxic_bert * nontoxic_mask).sum(dim=1) / nontoxic_mask.sum(dim=1).clamp(min=1.0)
nontoxic_latent_semantic = nontoxic_semantic(nontoxic_pool)
nontoxic_latent_style = nontoxic_style(nontoxic_pool)

In [None]:
# hidden2 may have been constructed with in_features=hidden1.in_features, which
# does not match hidden1's output width. Rebuild it with the matching input
# width and keep out_features so final_layer still lines up.
if hidden2.in_features != hidden1.out_features:
    hidden2 = nn.Linear(hidden1.out_features, hidden2.out_features).to(device)

# Reconstruction path: fuse the non-toxic semantic and style latents
fusion_latent = torch.cat((nontoxic_latent_semantic, nontoxic_latent_style), dim=1)
fusion_latent = fusion(fusion_latent)

# Dense stack. Dropout applies to hidden1's activations; it used to be applied
# to fusion_latent, which silently discarded hidden1 + ReLU.
h1 = nn.Dropout(0.2)(nn.ReLU()(hidden1(fusion_latent)))
h2 = nn.Dropout(0.2)(nn.ReLU()(hidden2(h1)))
f = final_layer(h2)
f = mapping_model(f)

# One vector per example, repeated across the source sequence length
f = f.unsqueeze(1).repeat(1, input_ids.shape[1], 1)
out = torch.argmax(f, dim=2)

In [None]:
# Reload GPT-2, this time with the model placed on `device`, plus its tokenizer
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
# Token-level cross-entropy between the predicted distributions in `f`
# (batch, seq_len, vocab) and the GPT-2 target ids. The previous MSE on argmax
# ids could not train anything: argmax has no gradient, and squared distance
# between integer token ids is meaningless. MSE on integer tensors is also the
# likely source of the RuntimeError that was recorded.
# Padding positions (mask == 0) are ignored.
# NOTE(review): assumes target_gpt_ids / target_gpt_masks are (batch, seq_len)
# with the same seq_len as `f`. TODO confirm.
F.cross_entropy(
    f.reshape(-1, f.size(-1)),
    target_gpt_ids.reshape(-1).masked_fill(target_gpt_masks.reshape(-1) == 0, -100),
    ignore_index=-100,
)

RuntimeError: ignored

In [None]:
#NEXT: IMPLEMENT LOSS

torch.Size([8, 128])

In [None]:
# Sequence length of the padded source batch (128 in the recorded output)
input_ids.shape[1]

128

In [None]:
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps=1e-6 and
# weight_decay=0.0 reproduce the transformers defaults (torch's own defaults
# are eps=1e-8, weight_decay=0.01).
# NOTE(review): `model` must already exist; a model created afterwards is not
# tracked by this optimizer.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)

In [None]:
model = TextStyleTransferModel(128, 0)
# Rebuild the optimizer so it tracks this freshly created model's parameters.
# Any optimizer built earlier holds the previous model's tensors.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, eps=1e-6, weight_decay=0.0)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.weight']
- T

In [None]:
# Run the model once and unpack its four latent outputs
(
    toxic_semantic_latent,
    toxic_style_latent,
    nontoxic_semantic_latent,
    nontoxic_style_latent,
) = model(input_ids, attention_mask, target_ids, target_attention_mask)

In [None]:
def loss_fn(mse, source_semantic, source_style, target_semantic, target_style, lambda_semantic=2.0, lambda_style=1.0):
    """Disentanglement objective over paired semantic / style latents.

    Pulls the semantic latents together and pushes the style latents apart:
    ``lambda_semantic * mse(source_semantic, target_semantic)
    - lambda_style * mse(source_style, target_style)``.

    Args:
        mse: loss callable applied to each pair, e.g. ``nn.MSELoss()``.
        source_semantic, target_semantic: semantic latents to align.
        source_style, target_style: style latents to separate.
        lambda_semantic: weight of the semantic (attractive) term.
        lambda_style: weight of the style (repulsive) term.

    Returns:
        The weighted loss value.

    Note:
        The style term is subtracted, so the objective is unbounded below.
        Minimizing it can push the style latents arbitrarily far apart.
    """
    semantic_term = mse(source_semantic, target_semantic)
    style_term = mse(source_style, target_style)
    return lambda_semantic * semantic_term - lambda_style * style_term

In [None]:
mse = nn.MSELoss()
# Pass the latents by keyword so each lands in the right slot of loss_fn.
# Previously the semantic/style arguments were swapped and
# nontoxic_latent_semantic was passed twice.
loss = loss_fn(
    mse,
    source_semantic=toxic_latent_semantic,
    source_style=toxic_latent_style,
    target_semantic=nontoxic_latent_semantic,
    target_style=nontoxic_latent_style,
    lambda_semantic=8,
    lambda_style=1,
)

In [None]:
# Display the loss tensor (value and grad_fn)
loss

tensor(0., grad_fn=<SubBackward0>)

In [None]:
# One optimization step driven by `loss`.
# NOTE(review): `loss` may be built from the standalone encoder heads rather
# than from `model`'s outputs, so its graph may not reach the parameters held
# by `optimizer`. TODO confirm. Re-running this cell after a successful
# backward() would also fail, because the graph has already been freed.
optimizer.zero_grad()
loss.backward()
optimizer.step()

RuntimeError: ignored

In [None]:
# Python float value of `loss`
loss.item()

0.0

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the training parameters
batch_size = 4
num_epochs = 10
learning_rate = 0.001

bert_checkpoint = 'distilbert-base-uncased'
gpt_checkpoint = 'gpt2'

bert_tokenizer = DistilBertTokenizer.from_pretrained(bert_checkpoint)
gpt_tokenizer = GPT2Tokenizer.from_pretrained(gpt_checkpoint)

# Set up the data loaders
train_dataset = TextStyleTransferDataset(data_path=train_file,
                                         tokenizer=bert_tokenizer,
                                         max_len=128)
eval_dataset = TextStyleTransferDataset(data_path=dev_file,
                                        tokenizer=bert_tokenizer,
                                        max_len=128)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

# Grab a single batch for interactive debugging. Batches are dicts keyed by
# 'source_inputs' / 'target_inputs'; unpacking one into two names would only
# yield its keys, so no text fields are extracted here.
batch = next(iter(train_loader))
input_ids = batch['source_inputs']['input_ids'].squeeze(1).to(device)
attention_mask = batch['source_inputs']['attention_mask'].squeeze(1).to(device)
target_ids = batch['target_inputs']['input_ids'].squeeze(1).to(device)
target_attention_mask = batch['target_inputs']['attention_mask'].squeeze(1).to(device)

content_dim = 256
style_dim = 256

bert_model = DistilBertModel.from_pretrained(bert_checkpoint).to(device)
gpt_model = GPT2LMHeadModel.from_pretrained(gpt_checkpoint).to(device)


def _draft_forward(self, content_input_ids, content_attention_mask,
                   style_input_ids=None, style_attention_mask=None, style_embeddings=None):
    """Draft `forward` for the style-transfer model, kept for reference.

    This code previously sat at top level, where `self` is undefined, so the
    cell raised NameError. Wrapping it as a function keeps the cell runnable;
    it expects `self` to provide bert_model, gpt_model, style_encoder,
    style_decoder and fusion.

    Returns:
        (gpt_logits, style_embedding)
    """
    content_outputs = self.bert_model(content_input_ids, attention_mask=content_attention_mask).last_hidden_state
    content_latent = content_outputs.mean(dim=1)

    if style_input_ids is not None and style_attention_mask is not None:
        style_outputs = self.bert_model(style_input_ids, attention_mask=style_attention_mask).last_hidden_state
        # `style_latent` was never assigned in the original draft; mean-pool the
        # style encoder outputs the same way as the content branch.
        style_latent = style_outputs.mean(dim=1)
        style_embedding = self.style_encoder(content_latent) + self.style_decoder(style_latent)
    else:
        style_embedding = style_embeddings

    # Concatenate content and style embeddings and fuse them
    fused_latent = torch.cat((content_latent, style_embedding), dim=1)
    fused_latent = self.fusion(fused_latent)

    gpt_input = fused_latent.unsqueeze(1).repeat(1, content_input_ids.shape[1], 1)
    gpt_output = self.gpt_model(inputs_embeds=gpt_input, attention_mask=content_attention_mask)

    return gpt_output.logits, style_embedding

In [None]:
# content_dim = 256
# style_dim = 256

# bert_model = DistilBertModel.from_pretrained(bert_checkpoint).to(device)
# gpt_model = GPT2LMHeadModel.from_pretrained(gpt_checkpoint).to(device)

# style_vector = None
# # Content Encoder
# content_encoder = nn.Sequential(
#     nn.Linear(bert_model.config.hidden_size, content_dim),
#     nn.ELU()
# ).to(device)

# # Style Encoder
# style_encoder = nn.Sequential(
#     nn.Linear(bert_model.config.hidden_size, style_dim),
#     nn.ELU()
# ).to(device)

# # Decoder
# decoder = nn.Linear(gpt_model.config.n_embd, gpt_model.config.vocab_size, bias=False)

# content_outputs = bert_model(content_inputs, attention_mask=content_attention_mask).last_hidden_state
# content_latent = content_encoder(content_outputs.mean(dim=1))

# fusion = nn.Linear(content_dim + style_dim, gpt_model.config.n_embd).to(device)
# if style_vector is None:
#     style_outputs = bert_model(style_inputs, attention_mask=style_attention_mask).last_hidden_state
#     style_latent = style_encoder(style_outputs.mean(dim=1))
# else:
#     style_latent = style_vector

# # Decoding
# fused_latent = torch.cat((content_latent, style_latent), dim=1)
# fused_latent = fusion(fused_latent)
# gpt_input = fused_latent.unsqueeze(1).repeat(1, content_inputs.shape[1], 1)
# gpt_output = gpt_model(inputs_embeds=gpt_input)

In [None]:
# Release unused blocks held by PyTorch's CUDA caching allocator (memory still
# referenced by live tensors is not freed)
torch.cuda.empty_cache()

In [None]:
# Probe the encoder directly to inspect which output fields it returns
_test = model.bert_model(input_ids, attention_mask=attention_mask)

In [None]:
# Output fields returned by the encoder's forward pass
_test.keys()

odict_keys(['last_hidden_state'])

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the training parameters
batch_size = 4
num_epochs = 10
learning_rate = 0.001

bert_checkpoint = 'distilbert-base-uncased'

bert_tokenizer = DistilBertTokenizer.from_pretrained(bert_checkpoint)

# Set up the data loaders
train_dataset = TextStyleTransferDataset(data_path=train_file,
                                         tokenizer=bert_tokenizer,
                                         max_len=128)
eval_dataset = TextStyleTransferDataset(data_path=dev_file,
                                        tokenizer=bert_tokenizer,
                                        max_len=128)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

# Set up the model (reuses bert_model / gpt_model from an earlier cell)
# NOTE(review): other cells construct TextStyleTransferModel with different
# argument lists; confirm this call matches the current class signature.
model = TextStyleTransferModel(bert_model, gpt_model)
model.to(device)

# torch.optim.AdamW with the transformers.AdamW defaults (eps=1e-6, no weight
# decay); transformers.AdamW is deprecated.
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    prog_bar = tqdm(train_loader, total=len(train_loader))
    prog_bar.set_description(f"Epoch {epoch+1}")

    for step, batch in enumerate(prog_bar, start=1):
        # Batches are dicts; the tokenizer adds an extra dim 1 that is squeezed away
        input_ids = batch['source_inputs']['input_ids'].squeeze(1).to(device)
        attention_mask = batch['source_inputs']['attention_mask'].squeeze(1).to(device)
        target_ids = batch['target_inputs']['input_ids'].squeeze(1).to(device)
        target_attention_mask = batch['target_inputs']['attention_mask'].squeeze(1).to(device)

        optimizer.zero_grad()
        logits, style_embedding = model(input_ids, attention_mask, target_ids, target_attention_mask)

        # The loss must be computed from this batch. Previously `loss` was a
        # stale tensor left over from an earlier cell. Padding is relabelled
        # to -100 so it does not contribute.
        # NOTE(review): target_ids come from the DistilBERT tokenizer while the
        # logits presumably span the GPT-2 vocabulary. TODO confirm targets.
        labels = target_ids.masked_fill(target_attention_mask == 0, -100)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)

        # Backpropagation and optimization step
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Running mean; equals the epoch average after the last step
        prog_bar.set_postfix(loss=running_loss / step)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch 1:   0%|          | 0/397 [00:00<?, ?it/s]


AttributeError: ignored

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Set up the training parameters
batch_size = 4
num_epochs = 10
learning_rate = 0.001

bert_checkpoint = 'distilbert-base-uncased'
gpt_checkpoint = 'gpt2'

bert_tokenizer = DistilBertTokenizer.from_pretrained(bert_checkpoint)
gpt_tokenizer = GPT2Tokenizer.from_pretrained(gpt_checkpoint)

# Set up the data loaders
train_dataset = TextStyleTransferDataset(data_path=train_file,
                                         tokenizer=bert_tokenizer,
                                         max_len=128)
eval_dataset = TextStyleTransferDataset(data_path=dev_file,
                                        tokenizer=bert_tokenizer,
                                        max_len=128)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

# Set up the model
bert_model = DistilBertModel.from_pretrained(bert_checkpoint)
gpt_model = GPT2LMHeadModel.from_pretrained(gpt_checkpoint)
model = TextStyleTransferModel(256, 256, bert_model, gpt_model)
model.to(device)

# Set up the optimizer and loss function. Padding positions are relabelled to
# -100 below so that this ignore_index actually applies.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

# One mean style vector per epoch (from the epoch's last training batch).
# Created once so it accumulates across epochs instead of being reset.
style_embeddings = []

for epoch in range(num_epochs):
    running_loss = 0.0

    model.train()
    prog_bar = tqdm(train_loader, total=len(train_loader))
    prog_bar.set_description(f"Epoch {epoch+1}")

    for data in prog_bar:
        source_input_ids = data['source_inputs']['input_ids'].squeeze(1).to(device)
        source_attention_mask = data['source_inputs']['attention_mask'].squeeze(1).to(device)
        target_input_ids = data['target_inputs']['input_ids'].squeeze(1).to(device)
        target_attention_mask = data['target_inputs']['attention_mask'].squeeze(1).to(device)

        optimizer.zero_grad()

        gpt_logits, style_embedding = model(source_input_ids, source_attention_mask, target_input_ids, target_attention_mask)

        # NOTE(review): targets are DistilBERT token ids while the logits span
        # gpt_model's vocabulary. TODO confirm the intended target vocabulary.
        target_labels = target_input_ids.masked_fill(target_attention_mask == 0, -100)
        loss = criterion(gpt_logits.reshape(-1, gpt_logits.size(-1)), target_labels.reshape(-1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        prog_bar.set_postfix(loss=loss.item())

    style_embeddings.append(style_embedding.mean(dim=0).detach().to('cpu').numpy())

    # Evaluate the model: same batch layout and model outputs as training
    model.eval()
    eval_loss = 0.0
    eval_correct = 0
    eval_tokens = 0
    with torch.no_grad():
        for data in eval_loader:
            source_input_ids = data['source_inputs']['input_ids'].squeeze(1).to(device)
            source_attention_mask = data['source_inputs']['attention_mask'].squeeze(1).to(device)
            target_input_ids = data['target_inputs']['input_ids'].squeeze(1).to(device)
            target_attention_mask = data['target_inputs']['attention_mask'].squeeze(1).to(device)

            gpt_logits, _ = model(source_input_ids, source_attention_mask, target_input_ids, target_attention_mask)
            target_labels = target_input_ids.masked_fill(target_attention_mask == 0, -100)
            loss = criterion(gpt_logits.reshape(-1, gpt_logits.size(-1)), target_labels.reshape(-1))

            eval_loss += loss.item()

            # Token accuracy over non-padding positions only
            predicted = gpt_logits.argmax(dim=-1)
            valid = target_labels != -100
            eval_correct += ((predicted == target_labels) & valid).sum().item()
            eval_tokens += valid.sum().item()

    eval_acc = eval_correct / max(eval_tokens, 1)
    print('Epoch %d train loss: %.3f eval loss: %.3f eval acc: %.3f' % (epoch + 1, running_loss / len(train_loader), eval_loss / len(eval_loader), eval_acc), '\n')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch 1:   0%|          | 0/397 [00:00<?, ?it/s]


RuntimeError: ignored

# Testing

In [None]:
# load the saved model
# NOTE(review): `save_path` is not assigned in any visible cell; confirm where
# the checkpoint is written.
# NOTE(review): load_state_dict copies into the existing parameters, so
# parameters created by the constructor keep their original device (presumably
# CPU) unless model.to(device) is called.
model = TextStyleTransferModel(content_dim=256, style_dim=256, bert_model=bert_model, gpt_model=gpt_model)
model.load_state_dict(torch.load(save_path, map_location=device))
model.eval()

# input offensive text
offensive_text = "I hate this place so much. The service is terrible and the food is disgusting."

# tokenize offensive text
# NOTE(review): `tokenizer` and `max_len` are not defined in the visible cells;
# other cells use `bert_tokenizer` and max_length=128 for the same purpose.
offensive_tokens = tokenizer(offensive_text, max_length=max_len, padding='max_length', truncation=True, return_tensors='pt')
offensive_inputs = offensive_tokens.input_ids.to(device)
offensive_attention_mask = offensive_tokens.attention_mask.to(device)

# generate neutralized text
with torch.no_grad():
    # NOTE(review): the DistilBERT output exposes only `last_hidden_state` (see
    # the `_test.keys()` output), so unpacking it into two names will fail.
    # NOTE(review): `model.style_embedding` and `model.decoder` are not
    # attributes of any model class visible in this notebook. TODO confirm.
    _, offensive_embedding = model.bert_model(offensive_inputs, attention_mask=offensive_attention_mask)
    neutralized_embedding = model.style_embedding(offensive_embedding)
    generated_ids = model.decoder.generate(neutralized_embedding, max_length=max_len, do_sample=True)
    neutralized_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
# print neutralized text
print(neutralized_text)

Sequential(
  (0): Linear(in_features=768, out_features=256, bias=True)
  (1): ReLU()
)

In [None]:
# Shape of the averaged learned style vector.
# NOTE(review): if the stored entry is already a 1-D per-epoch mean, this extra
# mean collapses it to a scalar; the recorded [768] implies a 2-D entry at run time.
torch.tensor(style_embeddings[-1]).mean(dim=0).shape

torch.Size([768])

In [None]:
class TextStyleTransferModelFailed1(nn.Module):
    """Earlier (failed) style-transfer attempt, kept for reference.

    DistilBERT hidden states are projected to a style embedding, mapped back to
    hidden-size space, and fed to the decoder as `inputs_embeds`. The content
    representation never reaches the decoder in this design.

    Args:
        content_dim, style_dim: accepted for signature compatibility; unused.
        bert_model: encoder exposing `config.hidden_size` and returning an
            object with `.last_hidden_state`.
        gpt_model: decoder accepting `inputs_embeds` and returning `.logits`.
            Assumes its embedding width equals the encoder hidden size.
    """

    def __init__(self, content_dim, style_dim, bert_model, gpt_model):
        # super() must refer to this class. Naming TextStyleTransferModel
        # raised a TypeError, because instances of this class are not instances of it.
        super().__init__()

        self.bert_model = bert_model
        self.gpt_model = gpt_model

        hidden_size = self.bert_model.config.hidden_size
        self.style_encoder = nn.Linear(hidden_size, hidden_size)
        self.style_decoder = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids, attention_mask, style_vector=None):
        """Return (gpt_logits, style_embedding).

        Args:
            input_ids, attention_mask: encoder inputs, (batch, seq_len).
            style_vector: optional style embedding, (hidden,) or (1, hidden),
                used instead of the one computed from the input and broadcast
                over the batch.
        """
        # Only last_hidden_state is used; DistilBERT outputs have no pooler_output
        bert_hidden_states = self.bert_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        if style_vector is None:
            # Project every token into style space and average over tokens
            style_embedding = self.style_encoder(bert_hidden_states).mean(dim=1)
        else:
            style_embedding = style_vector.to(bert_hidden_states).expand(bert_hidden_states.size(0), -1)

        # Decode back to hidden-size space as a length-1 sequence: inputs_embeds
        # must be (batch, seq_len, hidden), not the 2-D tensor squeeze(1) produced.
        gpt_input = self.style_decoder(style_embedding.unsqueeze(1))
        gpt_output = self.gpt_model(inputs_embeds=gpt_input)

        return gpt_output.logits, style_embedding

    def generate(self, content_inputs, content_attention_mask, style_vector, tokenizer=None):
        """Greedy (argmax) decode of forward()'s logits into strings.

        Args:
            tokenizer: object with `batch_decode`. Defaults to
                `self.gpt_model.tokenizer` as before; transformers models do
                not carry a tokenizer attribute, so pass one explicitly.
        """
        if tokenizer is None:
            tokenizer = self.gpt_model.tokenizer
        with torch.no_grad():
            logits, _ = self.forward(content_inputs, content_attention_mask, style_vector=style_vector)
            preds = torch.argmax(logits, dim=-1)
            return tokenizer.batch_decode(preds, skip_special_tokens=True)


def style_transfer_loss(generated_style, target_style, generated_logits, target_logits, lambda_style=1.0, lambda_logits=1.0):
    """Weighted sum of two mean-squared-error terms.

    Returns ``lambda_style * MSE(generated_style, target_style)
    + lambda_logits * MSE(generated_logits, target_logits)``.
    """
    style_term = nn.functional.mse_loss(generated_style, target_style)
    logits_term = nn.functional.mse_loss(generated_logits, target_logits)
    return lambda_style * style_term + lambda_logits * logits_term

In [None]:
source_text = "I love going to the beach on weekends."
# Learned style vector from training. The stored entry may be a 1-D per-epoch
# mean or a 2-D stack of vectors; average only in the 2-D case so a 1-D vector
# is not collapsed to a scalar.
style_vector = torch.as_tensor(style_embeddings[-1])
if style_vector.dim() > 1:
    style_vector = style_vector.mean(dim=0)
style_vector = style_vector.unsqueeze(0).to(device)
tokenized_source = bert_tokenizer(source_text, max_length=128, padding='max_length', truncation=True, return_tensors='pt')
content_inputs = tokenized_source['input_ids'].to(device)
content_attention_mask = tokenized_source['attention_mask'].to(device)
generated_text = model.generate(content_inputs, content_attention_mask, style_vector)

RuntimeError: ignored

In [None]:
# Pull the first training batch for inspection. squeeze(1) removes only the
# tokenizer's extra dimension; a bare squeeze() would also drop the batch
# dimension whenever a batch holds a single example.
data = next(iter(train_loader))
_content_inputs = data['content_input']['input_ids'].squeeze(1).to(device)
_content_attention_mask = data['content_input']['attention_mask'].squeeze(1).to(device)
_style_inputs = data['style_input']['input_ids'].squeeze(1).to(device)
_style_attention_mask = data['style_input']['attention_mask'].squeeze(1).to(device)

In [None]:
# (batch_size, max_len) in the recorded output
_content_inputs.shape

torch.Size([4, 128])

In [None]:
# Tokenize a toxic prompt and a neutral style reference with identical settings
_tokenize_kwargs = dict(max_length=128, padding='max_length', truncation=True, return_tensors='pt')
input_tokens = train_dataset.tokenizer('Why dont you fuck off', **_tokenize_kwargs)
style_tokens = train_dataset.tokenizer('You should leave', **_tokenize_kwargs)

In [None]:
# Forward the example pair (content, then style) through the model
logits = model(
    input_tokens['input_ids'].to(device),
    input_tokens['attention_mask'].to(device),
    style_tokens['input_ids'].to(device),
    style_tokens['attention_mask'].to(device),
)

In [None]:
predicted_tokens = logits.argmax(dim=-1)
# `decode` takes one sequence, but predicted_tokens is (batch, seq_len), which
# raised the TypeError. Decode the first (and, here, only) row.
# NOTE(review): predictions index the decoder's vocabulary but are decoded with
# the DistilBERT tokenizer. TODO confirm.
predicted_text = train_dataset.tokenizer.decode(predicted_tokens[0], skip_special_tokens=True)

TypeError: ignored

In [None]:
# Decode every row of the (batch, seq_len) prediction tensor.
# NOTE(review): uses the DistilBERT tokenizer on ids predicted over the
# decoder's vocabulary. TODO confirm.
train_dataset.tokenizer.batch_decode(predicted_tokens, skip_special_tokens=True)

["you '."]

In [None]:
# Raw GPT-2 generation from tokenized input, decoded with the dataset tokenizer.
# NOTE(review): `test_tokens` is not defined in the visible cells.
# NOTE(review): the decoded output shows [CLS]/[SEP]/[PAD], i.e. DistilBERT
# token ids are fed straight into GPT-2. The warning below also shows that
# max_length (20) is shorter than the 128-token prompt, so apparently nothing
# new is generated; consider max_new_tokens and passing the attention_mask.
train_dataset.tokenizer.batch_decode(model.gpt_model.generate(test_tokens['input_ids'].to(device)))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Input length of input_ids is 128, but `max_length` is set to 20. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.


['[CLS] shush you retarded dumbass [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']