This notebook adapts code from other public implementations of TextSETTR:
- base implementation of TextSETTR and dataset preparation: https://github.com/xiyan128/text_style_transfer_transformer
- add/delete rates: https://github.com/FabianBell/GuidedResearch/tree/textsettr

We also modify the original implementation of T5ForConditionalGeneration:
https://huggingface.co/transformers/v3.0.2/model_doc/t5.html

# Imports

In [None]:
!pip install transformers
!pip install pytorch-lightning
!pip install sentencepiece


from torch.utils.data import Dataset , DataLoader
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict , Counter
from itertools import chain
import ast
from transformers import T5TokenizerFast
from pytorch_lightning import Trainer
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers import TensorBoardLogger

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.30.1-py3-none-any.whl (7.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m63.1 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m123.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90

In [None]:
##################################
#   Mount Drive to Save Models   #
##################################

from google.colab import drive
drive.mount('/content/drive')

# Folder path to where models will be saved
# Note: folders must already exist to save them there
# `folder_path` deliberately has no trailing slash: every path derived from it
# (here and in the dataset-loading cell) adds its own leading '/'.
folder_path = './drive/My Drive/SAT/Transfer'
models_path = folder_path + '/models/'

Mounted at /content/drive


## Set seed for reproducibility

In [None]:
def set_seed(seed):
  """Seed every random number generator this notebook draws from.

  The noise utilities sample with ``np.random`` and the dataset samples with
  the standard-library ``random`` module, so seeding torch alone does not make
  a run repeatable.

  Args:
    seed: integer seed applied to ``random``, NumPy and torch (CPU and, when
      available, all CUDA devices).
  """
  import random  # local import so this cell does not depend on later cells

  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# set seed for reproducibility
#set_seed(42)

# Custom T5-based TextSETTR implementation

### Modified T5ForConditionalGeneration
Adapted from: https://huggingface.co/transformers/v3.0.2/model_doc/t5.html

In [None]:
from transformers.models.t5.modeling_t5 import T5Stack, T5PreTrainedModel
from transformers.modeling_outputs import (BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,)
from transformers.models.t5.configuration_t5 import T5Config
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
import warnings
import copy

# Text of the FutureWarning emitted by the model's forward pass when `head_mask`
# is given without `decoder_head_mask`.
# NOTE(review): a name with two leading underscores is subject to Python's
# private-name mangling when it is referenced from inside a class body, so a
# plain reference to this constant from a method will not resolve to it.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""

In [None]:
# Modified the T5ForConditionalGeneration class to include
# a parallel encoder to use as style extractor
# Also adapted xiyan128's t5_extractor implementation
# TODO: Add credit at the top of notebook

class T5forStyleExtraction(T5PreTrainedModel):
    """T5 encoder-decoder with an additional, parallel "style extractor" encoder.

    Three ``T5Stack`` modules are built from the same config: ``encoder``
    (content), ``extractor_encoder`` (style) and ``decoder``. In
    :meth:`forward` the style extractor's hidden states are added to the
    content encoder's hidden states, and one extra vector carrying the
    add/delete ranges is prepended along the sequence axis before the decoder
    cross-attends to the result.
    """

    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        """Build the shared embedding, the two encoders, the decoder and the LM head."""
        super().__init__(config)
        self.model_dim = config.d_model

        # One embedding table is handed to all three stacks at construction time.
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)

        # The style extractor is configured exactly like the content encoder.
        extractor_encoder_config = copy.deepcopy(config)
        extractor_encoder_config.is_decoder = False
        extractor_encoder_config.use_cache = False
        extractor_encoder_config.is_encoder_decoder = False
        self.extractor_encoder = T5Stack(extractor_encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def parallelize(self, device_map=None):
        """Spread the three stacks over several GPUs.

        Args:
            device_map: optional mapping of devices to block indices; when
                omitted, one is derived from the number of encoder blocks and
                the visible CUDA devices.
        """
        self.device_map = (
            get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.encoder.block))
        self.encoder.parallelize(self.device_map)
        self.extractor_encoder.parallelize(self.device_map)
        self.decoder.parallelize(self.device_map)
        self.lm_head = self.lm_head.to(self.decoder.first_device)
        self.model_parallel = True

    def deparallelize(self):
        """Undo :meth:`parallelize`: move every sub-module back to the CPU."""
        self.encoder.deparallelize()
        self.extractor_encoder.deparallelize()
        self.decoder.deparallelize()
        self.encoder = self.encoder.to("cpu")
        self.extractor_encoder = self.extractor_encoder.to("cpu")
        self.decoder = self.decoder.to("cpu")
        self.lm_head = self.lm_head.to("cpu")
        self.model_parallel = False
        self.device_map = None
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        """Return the embedding module stored in ``self.shared``."""
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        """Install ``new_embeddings`` on the model and on all three stacks."""
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.extractor_encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def get_encoder(self):
        """Return the content encoder."""
        return self.encoder

    def get_extractor_encoder(self):
        """Return the style extractor encoder."""
        return self.extractor_encoder

    def get_decoder(self):
        """Return the decoder."""
        return self.decoder

    def _run_extractor(self, ids, attention_mask, context_embeds, head_mask,
                       output_attentions, output_hidden_states, return_dict):
        """Run the style extractor on ``ids`` and return its complete output."""
        return self.extractor_encoder(
            input_ids=ids,
            attention_mask=attention_mask,
            inputs_embeds=context_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def get_extractor_output(self,
        input_ids=None,
        use_cache_context_ids=None, # use cache is simply a trick to use the generator mixin
        use_cache_target_examplars_ids=None,
        use_cache_origin_examplars_ids=None,
        extr_lambda=3,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        extractor_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        context_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,):
        """Compute the style representation that :meth:`forward` adds to the encoding.

        Two modes, selected by ``use_cache_context_ids``:

        * given (training): the extractor is run on the context ids (or the
          precomputed ``extractor_outputs`` is used) and its first output, the
          hidden states, is returned;
        * ``None`` (inference): the extractor is run on every target exemplar,
          every origin exemplar and on ``input_ids``; the result is
          ``extr_lambda * (mean(targets) - mean(origins)) + input_style``.

        Args:
            input_ids: ids of the sentence to transfer (inference mode only).
            use_cache_context_ids: ids of the style context (training mode).
            use_cache_target_examplars_ids: iterable of id tensors in the
                target style (inference mode).
            use_cache_origin_examplars_ids: iterable of id tensors in the
                source style (inference mode).
            extr_lambda: scale applied to the target-minus-origin difference.
            attention_mask, context_embeds, head_mask, output_attentions,
            output_hidden_states, return_dict: forwarded to the extractor.
            extractor_outputs: precomputed extractor output (training mode).

        The remaining parameters are accepted but not read by this method.

        Returns:
            The extractor hidden-state tensor described above.
        """
        extractor_args = (attention_mask, context_embeds, head_mask,
                          output_attentions, output_hidden_states, return_dict)

        if use_cache_context_ids is None:
            # inference: target exemplars first, then origin exemplars, then the input
            target_styles = tuple(
                self._run_extractor(target_ids, *extractor_args)[0]
                for target_ids in use_cache_target_examplars_ids
            )
            original_styles = tuple(
                self._run_extractor(origin_ids, *extractor_args)[0]
                for origin_ids in use_cache_origin_examplars_ids
            )
            input_style = self._run_extractor(input_ids, *extractor_args)[0]

            style_shift = torch.mean(torch.vstack(target_styles), 0) - torch.mean(torch.vstack(original_styles), 0)
            return extr_lambda * style_shift + input_style

        # training
        if extractor_outputs is None:
            extractor_outputs = self._run_extractor(use_cache_context_ids, *extractor_args)
        elif return_dict and not isinstance(extractor_outputs, BaseModelOutput):
            extractor_outputs = BaseModelOutput(
                last_hidden_state=extractor_outputs[0],
                hidden_states=extractor_outputs[1] if len(extractor_outputs) > 1 else None,
                attentions=extractor_outputs[2] if len(extractor_outputs) > 2 else None,)
        return extractor_outputs[0]

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        use_cache_extractor_outputs=None,
        ranges_prefix=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""Encode, inject style and add/delete prefix, decode, and optionally compute the LM loss.

        Args:
            input_ids, attention_mask, inputs_embeds, head_mask: inputs of the
                content encoder (skipped when ``encoder_outputs`` is given).
            encoder_outputs: precomputed content-encoder output.
            use_cache_extractor_outputs: style tensor from
                :meth:`get_extractor_output`, added element-wise to the
                encoder hidden states. ``None`` leaves the encoding unchanged.
            ranges_prefix: 2-D tensor (one row per batch element) that is
                inserted as an extra first position of the encoder hidden
                states; it is moved to the hidden states' device and dtype.
                ``None`` inserts nothing.
            decoder_input_ids, decoder_attention_mask, decoder_inputs_embeds,
            decoder_head_mask, cross_attn_head_mask, past_key_values,
            use_cache: forwarded to the decoder.
            labels: target ids; positions equal to ``-100`` are ignored by the
                loss. When given without decoder inputs, the decoder inputs
                are the labels shifted right.
            output_attentions, output_hidden_states, return_dict: as in
                Hugging Face models.

        Returns:
            ``Seq2SeqLMOutput`` (or the equivalent tuple when ``return_dict``
            is false) with ``loss`` set only when ``labels`` is given.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                # The module-level message starts with two underscores; a direct
                # reference inside a class body would be name-mangled and raise
                # NameError, so it is looked up by its string key instead.
                warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        # add extracted style to encoding
        hidden_states = encoder_outputs[0]
        if use_cache_extractor_outputs is not None:
            hidden_states = hidden_states + use_cache_extractor_outputs

        # a range of the noise modification is prepended to hidden states
        if ranges_prefix is not None:
            # Follow the hidden states' device/dtype instead of assuming 'cuda'.
            ranges_prefix = ranges_prefix.to(device=hidden_states.device, dtype=hidden_states.dtype)
            hidden_states = torch.cat((ranges_prefix[:, None, :], hidden_states), 1)

            if attention_mask is not None:
                # The extra prefix position must always be attended to.
                prepend = torch.ones(
                    attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device
                )
                attention_mask = torch.cat([prepend, attention_mask], 1)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Trimming the decoder inputs to the last token when cached key/values
        # are used is done in `prepare_inputs_for_generation`, not here.

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        use_cache_extractor_outputs=None,
        ranges_prefix=None,
        past=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs
    ):
        """Assemble the keyword arguments ``generate`` passes to :meth:`forward` at each step.

        The style tensor and the ranges prefix are passed through unchanged so
        that every decoding step sees them.

        NOTE(review): the cached key/values are read from the ``past``
        argument only; a caller that supplies them under another keyword
        (they would land in ``**kwargs``) gets uncached decoding — confirm
        against the installed transformers version.
        """
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "use_cache_extractor_outputs": use_cache_extractor_outputs,
            "ranges_prefix": ranges_prefix,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Return ``labels`` shifted one position to the right."""
        return self._shift_right(labels)

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached key/value state along the batch axis by ``beam_idx``."""
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past is None:
            # Use a standard-library logger: no module-level `logger` with a
            # `.warning` method is defined for this class to rely on.
            import logging

            logging.getLogger(__name__).warning(
                "You might want to consider setting `use_cache=True` to speed up decoding"
            )
            return past

        reordered_decoder_past = ()
        for layer_past_states in past:
            # get the correct batch idx from layer past batch dim
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past

### Noise Utils

In [None]:
# Module-level "t5-base" tokenizer; the noise and output helpers below read it as a global.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [None]:
# Fixed length, in tokens, that sentences are truncated or padded to throughout the notebook.
sent_length = 32

def drop_noise(sent, drop_rate):
  """Delete a random fraction of the ordinary tokens of a 1-D id tensor.

  Only positions whose id is greater than 1 are candidates, so ids 0 and 1
  are never removed. The surviving tokens keep their relative order.

  Args:
    sent: 1-D tensor of token ids.
    drop_rate: fraction of the candidate positions to delete.

  Returns:
    ``sent`` itself when nothing is dropped, otherwise a new, shorter tensor.
  """
  droppable = sent > 1
  candidates = np.flatnonzero(droppable.cpu().numpy())
  # Never ask for more deletions than there are candidates.
  n_drop = min(int(droppable.sum() * drop_rate), len(candidates))
  if n_drop <= 0:
    return sent

  # Choosing all positions at once (without replacement) selects a uniformly
  # random subset, and avoids rebuilding the tensor once per dropped token.
  dropped = np.random.choice(candidates, size=n_drop, replace=False)
  keep = torch.ones(sent.shape[0], dtype=torch.bool, device=sent.device)
  keep[torch.as_tensor(dropped, dtype=torch.long, device=sent.device)] = False
  return sent[keep]


# Ids that must never be produced as replacement noise.
special_tokens_set = set(tokenizer.all_special_ids)

def rand_token():
  """Draw a uniformly random vocabulary id that is not a special token.

  Differs from Riley et al.: the replacement comes from the whole vocabulary
  rather than from the same position of a different sentence.
  """
  # Rejection sampling: redraw until the id is not a special token.
  while True:
    candidate = np.random.randint(tokenizer.vocab_size)
    if candidate not in special_tokens_set:
      return candidate


def add_noise(sent, drop_rate):
  """Replace a random fraction of the ordinary tokens with random vocabulary ids.

  Despite the parameter name, nothing is dropped: each step overwrites one
  position whose id is greater than 1 with ``rand_token()``, so the length is
  unchanged. The same position may be picked more than once.

  Args:
    sent: 1-D tensor of token ids (on any device).
    drop_rate: fraction of the candidate positions to replace.

  Returns:
    ``sent`` itself when nothing is replaced, otherwise a modified copy on
    the same device and with the same dtype.
  """
  n_replace = int((sent > 1).sum() * drop_rate)
  if n_replace <= 0:
    return sent

  noisy = sent.clone()
  for _ in range(n_replace):
    # Candidates are recomputed every step, from the current contents.
    position = np.random.choice(np.flatnonzero((noisy > 1).cpu().numpy()))
    noisy[int(position)] = rand_token()
  return noisy

def pad_sent(sent, target=sent_length):
  """Truncate or zero-pad a 1-D id tensor to exactly ``target`` entries.

  Args:
    sent: 1-D tensor of token ids (on any device).
    target: required length; defaults to the notebook-wide ``sent_length``.

  Returns:
    The first ``target`` entries when ``sent`` is longer, otherwise ``sent``
    followed by zeros created on ``sent``'s device with ``sent``'s dtype.
  """
  length = sent.shape[0]
  if length > target:
    return sent[:target]
  padding = torch.zeros(target - length, dtype=sent.dtype, device=sent.device)
  return torch.cat((sent, padding))

### Add/Delete Rates Utils
Some functions adapted from Fabian Bell's TextSettr implementation: https://github.com/FabianBell/GuidedResearch/tree/textsettr

In [None]:
# NB! We treat noisy ids as the input (the actual input for the model)
# and original ids as the output (what we are trying to restore)
# Ex: input: cat really it; output: i really like it
def calculate_rates(original_ids, noisy_ids):
    """Measure how many tokens must be added/deleted to turn the original into the noisy text.

    Both sequences are compared as bags of tokens. For the example above
    ("i really like it" -> "cat really it"): "i" and "like" were deleted and
    "cat" was added, giving add_rate 1/4 and delete_rate 2/4.

    Args:
        original_ids: token ids of the clean sentence (list, tensor or array).
        noisy_ids: token ids of the corrupted sentence (list, tensor or array).

    Returns:
        ``(add_rate, delete_rate)``, both relative to ``len(original_ids)``;
        ``(0.0, 0.0)`` when the original is empty.
    """
    def as_token_list(ids):
        # Tensors hash by object identity, so counting their elements directly
        # would treat every element as a distinct token. Convert to plain
        # Python values first.
        return ids.tolist() if hasattr(ids, "tolist") else list(ids)

    original_tokens = as_token_list(original_ids)   # i really like it
    noisy_tokens = as_token_list(noisy_ids)         # cat really it
    if not original_tokens:
        return 0.0, 0.0

    original_counts = Counter(original_tokens)
    noisy_counts = Counter(noisy_tokens)

    # Counter subtraction keeps only positive differences.
    deleted = sum((original_counts - noisy_counts).values())
    added = sum((noisy_counts - original_counts).values())

    delete_rate = deleted / len(original_tokens)
    add_rate = added / len(original_tokens)

    return add_rate, delete_rate

def get_ranges(center):
    """Sample a random interval, clipped to [0, 1], around a true rate.

    A width is drawn uniformly from [0, 1], then the interval's midpoint is
    drawn uniformly from ``[center - width, center + width]``; the interval
    returned is ``midpoint +/- width``.

    NOTE(review): before clipping the interval therefore spans ``2 * width``;
    confirm this matches the intended "range width".

    Args:
        center: the true add or delete rate.

    Returns:
        ``(lower, upper)`` with ``lower >= 0`` and ``upper <= 1``.
    """
    half_span = np.random.rand()
    midpoint = np.random.uniform(low=center - half_span, high=center + half_span)
    return max(midpoint - half_span, 0), min(midpoint + half_span, 1)

# Hidden size of the T5 variant in use: 768 (t5-base). Switch to the
# commented value when working with t5-small.
# dim = 512
dim = 768

def create_prefix(ranges):
    """Embed the four range bounds in a zero vector of the model's hidden size.

    Args:
        ranges: the four values written to positions 0-3 of the vector.

    Returns:
        A plain Python list of ``dim`` floats; everything after index 3 is 0.0.
    """
    vector = np.zeros(dim)
    vector[:4] = ranges
    return vector.tolist()


def get_add_delete_rates(original_sents, noisy_sents):
    """Build one add/delete prefix row per sentence pair.

    For each (original, noisy) pair the true add and delete rates are
    measured, each rate is widened to a random range with ``get_ranges``, and
    ``[add_lower, add_upper, del_lower, del_upper]`` is embedded with
    ``create_prefix``.

    Args:
        original_sents: batch of clean token-id rows.
        noisy_sents: batch of corrupted token-id rows, aligned with
            ``original_sents``.

    Returns:
        A tensor with one prefix row per sentence pair.
    """
    def as_token_list(row):
        # Rates are computed by counting token values; tensor elements must be
        # turned into plain Python numbers first, or each element is counted
        # as a distinct token.
        return row.tolist() if hasattr(row, "tolist") else list(row)

    rows = []
    for original, noisy in zip(original_sents, noisy_sents):
        add_rate, del_rate = calculate_rates(as_token_list(original), as_token_list(noisy))
        add_lower, add_upper = get_ranges(add_rate)
        del_lower, del_upper = get_ranges(del_rate)
        ranges = [add_lower, add_upper, del_lower, del_upper]
        rows.append(torch.tensor(create_prefix(ranges)))
    return torch.vstack(rows)

In [None]:
# Create add/delete tuning rates for NBT
# Both lists are passed to create_prefix, which writes them to the first four
# positions of the prefix vector; the per-sentence prefixes built elsewhere in
# this notebook use the order [add_lower, add_upper, del_lower, del_upper].
# tuning_range feeds `default_prefix`; bt_tuning_range feeds `get_bt_null_prefix`.
tuning_range = [0.2, 0.4, 0.2, 0.4]
bt_tuning_range = [0.0, 0.0, 0.0, 0.0]


def get_prefix(add_del_rate):
    """Return a single-row prefix tensor for one set of four add/delete bounds."""
    row = torch.tensor(create_prefix(add_del_rate))
    # vstack of a single 1-D tensor yields a 2-D tensor with one row.
    return torch.vstack((row,))

# Prefix used when no explicit one is supplied at inference time.
default_prefix = get_prefix(tuning_range)

def get_bt_null_prefix(noisy_ids):
    """Return a batch_size x dim tensor of zeroed add/delete rates.

    One row built from ``bt_tuning_range`` is produced for every sentence in
    ``noisy_ids``; the sentences' contents are not read.
    """
    return torch.vstack(
        [torch.tensor(create_prefix(bt_tuning_range)) for _ in noisy_ids]
    )

In [None]:
def apply_noise(sents):
  """Corrupt every sentence of a batch: drop tokens, replace tokens, re-pad.

  Each sentence gets its own pair of noise probabilities, one per noise type,
  drawn uniformly from the 20%-60% range used by Riley et al.

  Args:
    sents: batch of token-id rows.

  Returns:
    The corrupted rows stacked into one tensor.
  """
  corrupted = []
  for row in sents:
    drop_prob, replace_prob = np.random.uniform(low=0.2, high=0.6, size=2)
    shortened = drop_noise(row, drop_prob)
    replaced = add_noise(shortened, replace_prob)
    corrupted.append(pad_sent(replaced))
  return torch.vstack(corrupted)

### Output Utils

In [None]:
def tokenize(input, device=None):
  """Tokenize text to a fixed-length id tensor and place it on a device.

  Args:
    input: text accepted by the notebook's tokenizer.
    device: where to put the ids. Defaults to CUDA when it is available and
      to the CPU otherwise.

  Returns:
    Tensor of token ids, padded/truncated to ``sent_length``.
  """
  if device is None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
  encoded = tokenizer(input, max_length=sent_length, truncation=True, padding="max_length", return_tensors="pt")
  return encoded.input_ids.to(device)

In [None]:
def peek_transfer_output(input, target_examplars, origin_examplars, extr_lambda, prefix=default_prefix):
  """Transfer one sentence towards the style of the target exemplars and decode it.

  Uses the notebook-global ``model``. Inference runs without gradient
  tracking and with the network temporarily in eval mode (so dropout does not
  perturb the style vectors); the previous train/eval mode is restored
  afterwards.

  Args:
    input: sentence to rewrite.
    target_examplars: sentences written in the desired style.
    origin_examplars: sentences written in the input's current style.
    extr_lambda: scale of the target-minus-origin style shift.
    prefix: add/delete range prefix handed to the generator.

  Returns:
    The decoded text of the first generated sequence.
  """
  net = model.net
  # Put every input on the device the weights live on.
  device = next(net.parameters()).device

  targets = tuple(tokenize(sent).to(device) for sent in target_examplars)
  origins = tuple(tokenize(sent).to(device) for sent in origin_examplars)
  input_ids = tokenize(input).to(device)

  was_training = net.training
  net.eval()
  try:
    with torch.no_grad():
      extractor_output = net.get_extractor_output(input_ids=input_ids, use_cache_origin_examplars_ids=origins, use_cache_target_examplars_ids=targets, extr_lambda=extr_lambda)
      outputs = net.generate(input_ids=input_ids, use_cache_extractor_outputs=extractor_output, no_repeat_ngram_size=2, ranges_prefix=prefix)
  finally:
    net.train(was_training)
  return tokenizer.decode(outputs[0], skip_special_tokens=True)

# TextSETTR Lightning Module

In [None]:
class TextSettrModel(LightningModule):
    """Lightning module that trains ``T5forStyleExtraction`` as a denoising model.

    Every batch holds three id tensors: the style context, the sentence to
    reconstruct, and a random context used for noisy back-translation (NBT).

    NOTE(review): the labels are the padded input ids, and the network's loss
    ignores only the value -100, so padding positions appear to contribute to
    the loss — confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        self.net = T5forStyleExtraction.from_pretrained("t5-base")
        # Start the style extractor from a copy of the pretrained encoder.
        self.net.extractor_encoder = copy.deepcopy(self.net.encoder)

    def _back_translate(self, noisy_input_ids, nbt_context_ids):
        """Rewrite the noisy batch by sampling from the model with a zeroed rate prefix."""
        bt_prefix = get_bt_null_prefix(noisy_input_ids)
        nbt_extractor_output = self.net.get_extractor_output(use_cache_context_ids=nbt_context_ids)
        return self.net.generate(input_ids=noisy_input_ids, use_cache_extractor_outputs=nbt_extractor_output,
                                 ranges_prefix=bt_prefix, do_sample=True, max_length=sent_length, min_length=sent_length)

    def _reconstruction_loss(self, context_ids, input_ids, noisy_input_ids):
        """Loss of restoring ``input_ids`` from its corrupted version, given the style context."""
        # Get add/delete rates prefix
        prefix = get_add_delete_rates(input_ids, noisy_input_ids)
        extractor_output = self.net.get_extractor_output(use_cache_context_ids=context_ids)
        outputs = self.net(input_ids=noisy_input_ids, labels=input_ids,
                           use_cache_extractor_outputs=extractor_output, ranges_prefix=prefix)
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Corrupt the inputs (token noise, plus NBT for a random half of the steps) and return the loss."""
        context_ids, input_ids, nbt_context_ids = batch[0], batch[1], batch[2]
        noisy_input_ids = apply_noise(input_ids)

        use_back_translation = np.random.choice([False, True])
        if use_back_translation:
            # Noisy back translation: the generated corruption is treated as
            # data, so no gradients are tracked while producing it.
            with torch.no_grad():
                noisy_input_ids = self._back_translate(noisy_input_ids, nbt_context_ids)

        return self._reconstruction_loss(context_ids, input_ids, noisy_input_ids)

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss``; validation always applies noisy back-translation."""
        context_ids, input_ids, nbt_context_ids = batch[0], batch[1], batch[2]

        noisy_input_ids = apply_noise(input_ids)
        noisy_input_ids = self._back_translate(noisy_input_ids, nbt_context_ids)
        self.log("val_loss", self._reconstruction_loss(context_ids, input_ids, noisy_input_ids))

    def configure_optimizers(self):
        """AdamW over all network parameters with learning rate 1e-3."""
        return torch.optim.AdamW(self.net.parameters(), 1e-3)

# Dataset Prep
Using a 1/100 random sample of the Amazon dataset (xiyan128's sample).
We plan to train on more data later.

In [None]:
# Load the review sample from the Drive folder. The column at position 1 is
# parsed with ast.literal_eval, i.e. it is stored in the CSV as a Python
# literal (presumably the "sents" list column used below — TODO confirm).
raw_data = pd.read_csv(folder_path + "/balanced_amazon_cond.csv", converters={1:ast.literal_eval})
# raw_data = raw_data.sample(n=100000)

In [None]:
# Balanced subsample: 80,000 random rows per `type`, keeping only the sentence
# lists. The explicit copy makes the result an independent frame, so columns
# can be added to it later without pandas' SettingWithCopyWarning.
raw_data = (
    raw_data.groupby(['type'])
    .apply(lambda grp: grp.sample(n=80000))[["sents"]]
    .copy()
)

In [None]:
# Each review of n sentences yields n - 1 consecutive (context, input) pairs.
pair_counts = raw_data["sents"].apply(lambda x: len(x) - 1)
print(pair_counts.sum())

# `cumlen` is the global index of each review's last pair; it becomes the frame's
# index so a global pair index can be mapped back to its review. `assign`
# builds a new frame instead of writing columns into a possibly-sliced one.
raw_data = raw_data.assign(cumlen=pair_counts.cumsum() - 1, len=pair_counts)
raw_data = raw_data.set_index("cumlen", drop=False)

pd.options.display.max_colwidth = 150

626266


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  raw_data["cumlen"] = raw_data["sents"].apply(lambda x: len(x) - 1).cumsum() - 1


In [None]:
# initialize tokenizer for Dataset building
# NOTE(review): this repeats the identical "t5-base" tokenizer load from the
# "Noise Utils" section and rebinds the same global name.
tokenizer = T5TokenizerFast.from_pretrained("t5-base")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [None]:
import random

def get_random_entry(id):
  """Return one random sentence from one random review of ``raw_data``.

  Args:
    id: unused; kept so existing callers keep working.

  Returns:
    A sentence drawn uniformly from a uniformly chosen review.
  """
  # Pick the review by position: this avoids materialising the whole index as
  # a list on every call and always yields a single row, even if index labels
  # repeat.
  row = random.randrange(len(raw_data))
  return random.choice(raw_data["sents"].iloc[row])

In [None]:
class TextSETTRDataset(Dataset):
    """Dataset of consecutive-sentence pairs drawn from the reviews it was given.

    A review with n sentences contributes n - 1 items; item k of a review is
    (sentence k, sentence k + 1), used as (style context, input). Every item
    also carries a random sentence from this dataset's own reviews, used as
    the context for noisy back-translation.

    All lookups use only the frame passed to the constructor, so datasets
    built from different splits never read each other's rows.
    """

    def __init__(self,data):
        """
        Args:
            data: DataFrame with a "sents" column holding lists of sentences.
        """
        self.data = data
        self.sents = list(data["sents"])
        # Reviews with fewer than two sentences contribute no pairs.
        pair_counts = np.array([max(len(s) - 1, 0) for s in self.sents], dtype=np.int64)
        # pair_ends[i] = number of pairs contained in reviews 0..i
        self.pair_ends = np.cumsum(pair_counts)
        self.len = int(self.pair_ends[-1]) if len(self.pair_ends) else 0

    def __len__(self):
        return self.len

    def to_token(self,sentence):
        """Encode one sentence to a fixed-length 1-D tensor of token ids."""
        return tokenizer.encode(sentence, max_length=sent_length, truncation=True, padding="max_length", return_tensors="pt")[0]

    def _random_sentence(self):
        """Return a random sentence from a random review of this dataset."""
        import random  # local import keeps the class self-contained

        return random.choice(random.choice(self.sents))

    def get_inputs(self, idx):
        """Return ``(context, input, nbt_context)`` sentences for pair ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``.
        """
        idx = int(idx)
        if not 0 <= idx < self.len:
            raise IndexError(f"pair index {idx} out of range for {self.len} pairs")

        # First review whose cumulative pair count exceeds idx (binary search).
        row = int(np.searchsorted(self.pair_ends, idx, side="right"))
        line = self.sents[row]
        first_pair_of_row = int(self.pair_ends[row]) - (len(line) - 1)
        offset = idx - first_pair_of_row

        return (line[offset], line[offset + 1], self._random_sentence())

    def __getitem__(self,index):
        context, input, nbt_context = self.get_inputs(index)
        return self.to_token(context), self.to_token(input), self.to_token(nbt_context)

In [None]:
# Number of sentence pairs per batch for every dataloader below.
batch_size = 64

class TextSETTRDataModule(pl.LightningDataModule):
    """Splits ``raw_data`` 99% / 1% and serves train, validation and test loaders.

    The validation and test datasets are both built from the same held-out
    1% split.
    """

    def __init__(self):
        super().__init__()
        train_dataset, val_dataset = train_test_split(raw_data, test_size=0.01)
        self.train = TextSETTRDataset(train_dataset)
        self.test = TextSETTRDataset(val_dataset)
        self.val = TextSETTRDataset(val_dataset)

    def _loader(self, dataset, shuffle):
        """Build a 4-worker DataLoader with the notebook-wide batch size."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)

    def train_dataloader(self):
        return self._loader(self.train, True)

    def test_dataloader(self):
        return self._loader(self.test, False)

    def val_dataloader(self):
        return self._loader(self.val, False)

# Train model

In [None]:
model = TextSettrModel()
module = TextSETTRDataModule()
# Keep the checkpoint with the best validation loss, written to the Drive
# models folder. The callback only takes effect when it is passed to the Trainer.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor="val_loss", every_n_epochs=1, dirpath=models_path)
logger = TensorBoardLogger("logs", name="style_transfer")
trainer = Trainer(max_epochs = 2, default_root_dir=models_path, check_val_every_n_epoch=1, precision=32,
                  logger=logger, callbacks=[checkpoint_callback])
trainer.fit(model,module)

Downloading model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Some weights of T5forStyleExtraction were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor_encoder.block.2.layer.1.layer_norm.weight', 'extractor_encoder.block.7.layer.1.DenseReluDense.wo.weight', 'extractor_encoder.block.0.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.7.layer.0.SelfAttention.v.weight', 'extractor_encoder.block.4.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.3.layer.1.layer_norm.weight', 'extractor_encoder.block.11.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.0.layer.0.SelfAttention.q.weight', 'extractor_encoder.block.7.layer.1.layer_norm.weight', 'extractor_encoder.block.5.layer.1.layer_norm.weight', 'extractor_encoder.block.3.layer.0.SelfAttention.o.weight', 'extractor_encoder.block.5.layer.0.layer_norm.weight', 'extractor_encoder.block.0.layer.0.layer_norm.weight', 'extractor_encoder.block.9.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.8.layer.0.SelfAttention.o.weight', 'ext

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name | Type                 | Params
----------------------------------------------
0 | net  | T5forStyleExtraction | 332 M 
----------------------------------------------
332 M     Trainable params
0         Non-trainable params
332 M     Total params
1,330.128 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


In [None]:
# Write the trainer's current state to a checkpoint file.
# The path is built by plain string concatenation, so models_path must already
# end with a path separator.
trainer.save_checkpoint(models_path + "nbt_2_epochs.ckpt")

In [None]:
# Load the TensorBoard notebook extension and open it on logs/style_transfer,
# the directory/name pair given to TensorBoardLogger in the training cell.
%load_ext tensorboard
%tensorboard --logdir logs/style_transfer

<IPython.core.display.Javascript object>

In [None]:
# Rebind `model` to a TextSettrModel restored from a checkpoint file.
# NOTE(review): this loads "balanced_2_epochs.ckpt", whereas the training
# section above saves "nbt_2_epochs.ckpt" - confirm this is the checkpoint you
# intend to evaluate.
model = TextSettrModel.load_from_checkpoint(models_path + "balanced_2_epochs.ckpt")

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Some weights of T5forStyleExtraction were not initialized from the model checkpoint at t5-base and are newly initialized: ['extractor_encoder.block.0.layer.1.DenseReluDense.wi.weight', 'extractor_encoder.block.0.layer.1.layer_norm.weight', 'extractor_encoder.block.4.layer.1.DenseReluDense.wo.weight', 'extractor_encoder.block.4.layer.0.SelfAttention.q.weight', 'extractor_encoder.block.6.layer.1.DenseReluDense.wi.weight', 'extractor_encoder.block.1.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.10.layer.1.DenseReluDense.wi.weight', 'extractor_encoder.block.4.layer.0.layer_norm.weight', 'extractor_encoder.block.5.layer.1.layer_norm.weight', 'extractor_encoder.block.0.layer.0.SelfAttention.q.weight', 'extractor_encoder.block.1.layer.0.SelfAttention.v.weight', 'extractor_encoder.block.8.layer.1.DenseReluDense.wo.weight', 'extractor_encoder.block.0.layer.0.SelfAttention.k.weight', 'extractor_encoder.block.7.layer.1.DenseReluDense.wo.weight', 'extractor_encoder.block.10.layer.0.Sel

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# Move the model onto the GPU before running the transfer.
model.to('cuda')

# Sentence to rewrite, plus the factor forwarded to peek_transfer_output
# (presumably the strength of the style shift - confirm against that helper).
neutral_input = "Do you believe that you should be the saviour of someone else?"
lambda_factor = 9

# Exemplars in a plain, neutral register.
neutral_examplars = [
    "Now I will ask some questions to understand your situation.",
    "In previous conversations, have you considered other viewpoints presented?",
    "Are you always blaming and accusing yourself for when something goes wrong?",
    # NOTE(review): "hime" looks like a typo for "him"; kept verbatim because
    # the string is model input.
    "Tell me next time you see hime.",
    "I understand how you feel",
]

# Exemplars phrased in a softer, empathetic register.
sad_empathetic_examplars = [
    "May I ask if you feel able to consider other people's points of view?",
    "If I may ask this, do you always see yourself at fault for anything that happens?",
    "How about we go through a few questions and look into some approaches that might help you feel better?",
    "I am sorry for asking this but can you tell me if you see him next time.",
    "I appreciate your feelings, you are valid and heard.",
]

# Four values turned into a prefix by get_prefix (presumably add/delete rate
# bounds - confirm against get_prefix).
add_del_range = [0.8, 0.9, 0.0, 0.0]
prefix = get_prefix(add_del_range)

# Last expression of the cell: its return value is what the notebook displays.
peek_transfer_output(
    neutral_input,
    sad_empathetic_examplars,
    neutral_examplars,
    lambda_factor,
    prefix,
)

'Do you believe that you may believe this that we should be the saviour of someone'

In [None]:
# Move the model onto the GPU before running the transfer.
model.to('cuda')

# Sentence to rewrite, plus the factor forwarded to peek_transfer_output
# (presumably the strength of the style shift - confirm against that helper).
neutral_input = "Are you frustrated right now with trying to control someone or something that you can't?"
lambda_factor = 12

# First exemplar set; passed as the third argument of the call below.
sad_examplars = [
    "Are you unhappy because someone else has made you feel bad, like you've been hurt by them and you're the victim?",
    "So that I can help you better, could you tell me if a specific event caused you to feel this way?",
    "Have you attempted exercise 11 anytime recently, and if so did it spark difficult emotions for you remembering this event?",
    "When something goes wrong, do you always tend to blame yourself regardless of whether you really made a mistake?",
    "I appreciate your efforts, even though it must be hard for you.",
]

# Second exemplar set; passed as the second argument of the call below.
angry_examplars = [
    "If I may ask, do you ever find yourself blaming other people for the way you feel?",
    "Would you say that this feeling you are experiencing right now was caused by a specific event or events?",
    "Did you find exercise 11 brought up situations causing you to become uncontrollably emotional?",
    "When something does not work out, do you harshly blame and accuse yourself?",
    "I'm sorry you feel that way and I would like to know more about what happened.",
]

# Four values turned into a prefix by get_prefix (presumably add/delete rate
# bounds - confirm against get_prefix).
add_del_range = [0.2, 1.0, 0.2, 0.8]
prefix = get_prefix(add_del_range)

# Last expression of the cell: its return value is what the notebook displays.
peek_transfer_output(
    neutral_input,
    angry_examplars,
    sad_examplars,
    lambda_factor,
    prefix,
)

"Are you frustrated right now with trying to control someone or something that you can't control?"

In [None]:
# Move the model onto the GPU before running the transfer.
model.to('cuda')

# Sentence to rewrite, plus the factor forwarded to peek_transfer_output
# (presumably the strength of the style shift - confirm against that helper).
neutral_input = "I want to ask you some more questions to better understand why you feel that way."
lambda_factor = 12

# First exemplar set; passed as the second argument of the call below.
sad_examplars = [
    "Are you unhappy because someone else has made you feel bad, like you've been hurt by them and you're the victim?",
    "So that I can help you better, could you tell me if a specific event caused you to feel this way?",
    "Have you attempted exercise 11 anytime recently, and if so did it spark difficult emotions for you remembering this event?",
    "When something goes wrong, do you always tend to blame yourself regardless of whether you really made a mistake?",
    "I appreciate your efforts, even though it must be hard for you.",
]

# Reworded versions of the five sentences above; passed as the third argument.
paraphrased_sad_examplars = [
    "Do you feel discontented because another person has caused you emotional distress, as if they have inflicted harm upon you and you perceive yourself as the one suffering?",
    "In order for me to assist you more effectively, could you provide information about whether a particular incident triggered the emotions you're experiencing?",
    "Have you recently made any attempts at exercise 11, and if you have, did it elicit challenging emotions as you recalled this particular event?",
    "Do you have a tendency to hold yourself responsible and attribute blame to yourself whenever something goes awry, irrespective of whether you actually made an error?",
    "I value the endeavors you have made, recognizing that it may be challenging for you.",
]

# Four values turned into a prefix by get_prefix (presumably add/delete rate
# bounds - confirm against get_prefix).
add_del_range = [0.1, 0.8, 0.2, 0.9]
prefix = get_prefix(add_del_range)

# Last expression of the cell: its return value is what the notebook displays.
peek_transfer_output(
    neutral_input,
    sad_examplars,
    paraphrased_sad_examplars,
    lambda_factor,
    prefix,
)

'I just want to ask you some more questions to better understand why you feel that way.'

In [None]:
# Move the model onto the GPU before running the transfer.
model.to('cuda')

# Three formal exemplars. The comma after the second item was missing, which
# made Python concatenate the second and third literals into one string and
# left the list with only two entries.
formal_examplars = [
    "This was a remarkably thought-provoking read.",
    "It is certainly amongst my favorites.",
    "We humbly request your presence at our gala on the 12th.",
]

# Three informal exemplars.
informal_examplars = [
    "reading this rly makes u think",
    "Its def one of my favs",
    "come swing by our bbq next week if ya can make it",
]

formal_input = "I hereby commit to never purchase anything from this institution in the future."
lambda_factor = 6

# NOTE(review): `prefix` is not assigned in this cell; the call uses whatever
# value an earlier cell left in the kernel. Set add_del_range / prefix here if
# this experiment needs its own setting.
peek_transfer_output(formal_input, informal_examplars, formal_examplars, lambda_factor, prefix)

'im gonna commit to never purchase anything from this this im imgonna ill a'