In [55]:

import torch.nn as nn
import os
from utils.dataloader import ParallelDataset, GPTDataset, SCIDataset, CombinedDataset
import torch
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

#https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial11/NF_image_modeling.ipynb#scrollTo=2MaRRpxL1MPG
# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


config = Config()
config.device

[rank: 0] Seed set to 42


device(type='cuda')

### model process
the model first encodes the input sentence x into a vector, in the forward process. 
for disentanglement the model disentangles the encoded vector into content and style vectors. 
in the reverse step the model generates the target style vector and then concatenate it with the source content vector to generate an input vector
it then uses the reversible encoder to decode this representation into the target sentence.


- for encoding
    - the model passes the input sentence into a pretrained GRU model with attention. The attention assigns a score to each token in the sentence that refers to their style level. This splits the input into style and content which then is passed into the coupling layer.
- disentanglement
    - the model aims to split the latent vector into a style and content vector. which indicate style information and content informaation. the split is made from the attention weights. 

In [63]:
#https://github.com/bentrevett/pytorch-seq2seq/blob/main/2%20-%20Learning%20Phrase%20Representations%20using%20RNN%20Encoder-Decoder%20for%20Statistical%20Machine%20Translation.ipynb

class Attention(nn.Module):
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super(Attention, self).__init__()
        self.attn_fc = nn.Linear(
            (encoder_hidden_dim * 2) + decoder_hidden_dim, decoder_hidden_dim
        )
        self.v_fc = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden = [batch size, decoder hidden dim]
        # encoder_outputs = [src length, batch size, encoder hidden dim * 2]
        batch_size = encoder_outputs.shape[1]
        src_length = encoder_outputs.shape[0]
        # repeat decoder hidden state src_length times
        hidden = hidden.unsqueeze(1).repeat(1, src_length, 1)
        encoder_outputs = encoder_outputs.permute(1, 0, 2)
        # hidden = [batch size, src length, decoder hidden dim]
        # encoder_outputs = [batch size, src length, encoder hidden dim * 2]
        energy = torch.tanh(self.attn_fc(torch.cat((hidden, encoder_outputs), dim=2)))
        # energy = [batch size, src length, decoder hidden dim]
        attention = self.v_fc(energy).squeeze(2)
        # attention = [batch size, src length]
        return torch.softmax(attention, dim=1)

class Encoder(nn.Module):
    def __init__(
        self, config
    ):
        super().__init__()
        self.rnn = nn.GRU(config.hidden_dims, config.hidden_dims, bidirectional=True)
        self.fc = nn.Linear(config.hidden_dims * 2, config.hidden_dims)
        self.dropout = nn.Dropout(config.dropout)
        self.attention = Attention(config.hidden_dims, config.hidden_dims)
        # Add a linear layer to reduce the dimensions
        self.reduce_dim = nn.Linear(config.hidden_dims * 2, config.hidden_dims)
        
    def forward(self, src):
        # src = [src length, batch size]
        outputs, hidden = self.rnn(src)
        # outputs = [src length, batch size, hidden dim * n directions]
        # hidden = [n layers * n directions, batch size, hidden dim]
        # hidden is stacked [forward_1, backward_1, forward_2, backward_2, ...]
        # outputs are always from the last layer
        # hidden [-2, :, : ] is the last of the forwards RNN
        # hidden [-1, :, : ] is the last of the backwards RNN
        # initial decoder hidden is final hidden state of the forwards and backwards
        # encoder RNNs fed through a linear layer
        hidden = torch.tanh(
            self.fc(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        )
        

        # embedded = [1, batch size, embedding dim]
        a = self.attention(hidden, outputs)
        # a = [batch size, src length]
        outputs = self.reduce_dim(outputs)
        # outputs = [src length, batch size, encoder hidden dim * 2]
        # hidden = [batch size, decoder hidden dim]
        return outputs, hidden, a
encoder = Encoder(config)
# Assume the following dimensions
src_length = 10
batch_size = 4
hidden_dim = config.hidden_dims
test_in = torch.randn(src_length, batch_size, config.encoder_hidden_dim)

outputs, hidden, attention_scores = encoder(test_in)
print(f'src length {src_length}, batch_size {batch_size}, hidden size {hidden_dim}')
print("Outputs shape:", outputs.shape)
print("Hidden shape:", hidden.shape)
print("Attention scores shape:", attention_scores.shape)


src length 10, batch_size 4, hidden size 256
Outputs shape: torch.Size([10, 4, 256])
Hidden shape: torch.Size([4, 256])
Attention scores shape: torch.Size([4, 10])


In [77]:
class AttentionAwareCoupling(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dims
        self.num_heads = config.n_heads
        self.dropout = config.dropout
        self.transformer_block = nn.TransformerEncoderLayer(
            d_model=config.hidden_dims,
            nhead=config.n_heads,
            dim_feedforward=config.hidden_dims * 4,
            dropout=config.dropout,
            activation="relu"
        )
        self.encoder = Encoder(config)
        self.scaling_factor = nn.Parameter(torch.zeros(config.hidden_dims))
        self.fc_s = nn.Linear(config.hidden_dims, config.hidden_dims)
        self.fc_t = nn.Linear(config.hidden_dims, config.hidden_dims)
        self.threshold = config.attention_threshold

    def forward(self, x, a, ldj, reverse=False):
        # x = [src length, batch size, hidden dim]
        # a = [batch size, src length]
        # ldj = [batch size]
        # Split the input into content and style parts based on attention scores
        # first encode the sentence
        output, hidden, attention_scores = self.encoder(x)
        # filter the content and style by the attention generated from the model
        content_mask = attention_scores > self.threshold
        style_mask = attention_scores <= self.threshold
        # transpose the content and style masks so they can be used on tokens
        content_mask = content_mask.transpose(0, 1)
        style_mask = style_mask.transpose(0, 1)
        #filter the tokens to create xs and xc
        content_tokens = output * content_mask.unsqueeze(-1)
        style_tokens = output * style_mask.unsqueeze(-1)
        print(content_tokens.shape)
        # Pass the content tokens through the Transformer block
        content_out = self.transformer_block(content_tokens)

        # Generate the affine transformation parameters
        s = self.fc_s(content_out)
        t = self.fc_t(content_out)
        print(s.shape, style_tokens.shape)
        s_fac = self.scaling_factor.exp().view(1, -1, 1, 1)
        s_fac = s_fac.expand(s.shape[0], s.shape[1], -1, s.shape[3])

        print(s.shape, style_tokens.shape, s_fac)
        s = torch.tanh(s / s_fac) * s_fac

        if not reverse:
            # Apply the affine transformation only to the relevant dimensions of style_tokens
            
            style_tokens = torch.exp(s) * style_tokens + t
            ldj += s.sum(dim=[1, 2])
        else:
            style_tokens = (style_tokens - t) / torch.exp(s)
            ldj -= s.sum(dim=[1, 2])

        # Concatenate the content and style outputs
        output = torch.zeros_like(x)
        output[:, content_mask] = content_tokens
        output[:, style_mask] = style_tokens

        return output, ldj
    
layer = AttentionAwareCoupling(config)
# Assume the following dimensions
src_length = 10
batch_size = 4
hidden_dim = config.hidden_dims

# Create a dummy input tensor
dummy_input = torch.randn(src_length, batch_size, config.hidden_dims)



# Create dummy attention scores
dummy_attention_scores = torch.randn(batch_size, src_length)
dummy_ldj = torch.zeros(batch_size, device=config.device)
output, updated_ldj = layer(dummy_input, dummy_attention_scores, dummy_ldj)
print("Output shape:", output.shape)
print("Updated LDJ shape:", updated_ldj.shape)


AttributeError: 'Config' object has no attribute 'batch_sizeconfig'

In [None]:
class StyleFlow(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.encoder = Encoder()
        self.coupling_layers = nn.ModuleList(coupling_layers)
        self.prior = prior
        self.import_samples = import_samples
        self.embedding = nn.Embedding()

    def forward(self, sentences):
        return self._get_likelihood(sentences)

    def encode(self, sentences):
        # Encode the input sentences using the GRU encoder and attention-aware coupling layers
        
        z, ldj = sentences, torch.zeros(sentences.shape[0], device=self.device)
        for coupling_layer in self.coupling_layers:
            z, ldj = coupling_layer(z, attention_scores, ldj)
        return z, ldj

    def _get_likelihood(self, sentences, return_ll=False):
        z, ldj = self.encode(sentences)
        log_pz = self.prior.log_prob(z).sum(dim=[1, 2])
        log_px = ldj + log_pz
        nll = -log_px
        # Calculating bits per dimension (assuming binary representation of words)
        bpd = nll * np.log2(np.exp(1)) / np.prod(sentences.shape[1:])
        return bpd.mean() if not return_ll else log_px

    @torch.no_grad()
    def sample(self, sentence_shape, z_init=None):
        # Sample latent representation from prior
        if z_init is None:
            z = self.prior.sample(sample_shape=sentence_shape).to(self.device)
        else:
            z = z_init.to(self.device)
        
        # Transform z to sentence by inverting the coupling layers
        ldj = torch.zeros(sentence_shape[0], device=self.device)
        for coupling_layer in reversed(self.coupling_layers):
            z, ldj = coupling_layer.inverse(z, ldj)
        
        # Decode the latent representation to obtain the generated sentence
        generated_sentence = self.decoder(z)
        return generated_sentence

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        sentences = batch[0]
        loss = self._get_likelihood(sentences)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        sentences = batch[0]
        loss = self._get_likelihood(sentences)
        self.log('val_loss', loss)

    def test_step(self, batch, batch_idx):
        sentences = batch[0]
        samples = []
        for _ in range(self.import_samples):
            sentence_ll = self._get_likelihood(sentences, return_ll=True)
            samples.append(sentence_ll)
        sentence_ll = torch.stack(samples, dim=-1)
        sentence_ll = torch.logsumexp(sentence_ll, dim=-1) - np.log(self.import_samples)
        bpd = -sentence_ll * np.log2(np.exp(1)) / np.prod(sentences.shape[1:])
        bpd = bpd.mean()
        self.log('test_loss', bpd)

In [13]:
class Config():
    def __init__(self):
        self.n_heads = 4
        self.model_dims = 256
        self.hidden_dims = 256
        self.trans_embed_size = 256
        self.trans_pos_enc_size = 256
        self.CLN_bias = 256
        self.CLN_gain = 256
        self.NFC_len = 8
        self.loss_weight_1 = 0.5
        self.loss_weight_2 = 0.5
        self.loss_weight_3 = 1
        self.loss_weight_4 = 1
        self.attention_threshold = 0.5
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.input_dim = 3000 #vocabsize??
        self.embedding_dim = 2*8 #experimental
        self.dropout = 0
        self.encoder_hidden_dim = 256
        #maybe we just use RoBERTa for this???
        
config = Config()