In [None]:
import torch
import torch.nn as nn                                            
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn.functional as F


class Encoder(nn.Module):
    """Posterior network q(z | x): pretrained transformer plus two linear heads.

    The transformer's hidden state at the first token position (the [CLS]
    token when inputs are built with ``add_special_tokens=True``) is projected
    to the mean and the log-variance of a diagonal Gaussian over ``z``.

    Args:
        bert_encoder: model called as ``bert_encoder(input_ids, attention_mask=...)``
            whose result exposes ``last_hidden_state`` indexed as
            ``(batch, seq, hidden_dim)``.
        hidden_dim: width of ``last_hidden_state``.
        z_dim: latent dimensionality.
    """

    def __init__(self, bert_encoder, hidden_dim, z_dim):
        super().__init__()
        self.bert = bert_encoder
        # Attribute names double as checkpoint state-dict keys; keep them stable.
        self.hidden2mean = nn.Linear(hidden_dim, z_dim)
        self.hidden2logvar = nn.Linear(hidden_dim, z_dim)

    def forward(self, input_ids, attention_mask):
        """Return ``(mean, logvar)``, each shaped ``(batch, z_dim)``."""
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0]
        return self.hidden2mean(first_token), self.hidden2logvar(first_token)
    
class Decoder(nn.Module):
    """GRU decoder mapping a latent vector ``z`` to per-step vocabulary logits.

    ``fc`` projects ``z`` (through ``tanh``) to the GRU's initial hidden state.
    At every step the GRU consumes one input embedding, and ``output_layer``
    turns its output into ``vocab_size`` logits.

    Args:
        z_dim: size of the latent vector.
        hidden_dim: GRU hidden size, also used as the token-embedding size.
        vocab_size: number of output classes / embedding rows.
    """

    def __init__(self, z_dim, hidden_dim, vocab_size):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(z_dim, hidden_dim)
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, z, target_ids=None, teacher_forcing_ratio=0.5, max_length=20):
        """Unroll the GRU and return the logits produced at every step.

        Args:
            z: latent batch of shape ``(batch, z_dim)``.
            target_ids: optional ``(batch, seq_len)`` token ids. When given, the
                decoder runs ``seq_len`` steps and may feed ``target_ids[:, t]``
                as the input of step ``t + 1`` (teacher forcing).
            teacher_forcing_ratio: per-step probability of feeding the target
                token instead of the argmax prediction. One random draw per
                step applies to the whole batch. It is ignored when
                ``target_ids`` is None.
            max_length: number of steps to run when ``target_ids`` is None.
                The default of 20 is the previously hard-coded value.

        Returns:
            Raw logits of shape ``(batch, steps, vocab_size)``.
        """
        # (1, batch, hidden): initial hidden state for the single-layer GRU
        # (h is not batch-first even though inputs are).
        h = torch.tanh(self.fc(z)).unsqueeze(0)
        batch_size = z.size(0)
        steps = target_ids.size(1) if target_ids is not None else max_length

        outputs = torch.zeros(batch_size, steps, self.output_layer.out_features, device=z.device)

        # The step-0 input is an all-zero "embedding". Its width is taken from
        # this module's own embedding layer rather than from a global
        # ``hidden_dim``, so the class works outside this notebook's namespace.
        input_token = torch.zeros(batch_size, 1, self.embedding.embedding_dim, device=z.device)

        for t in range(steps):
            output, h = self.gru(input_token, h)
            step_logits = self.output_layer(output.squeeze(1))
            outputs[:, t, :] = step_logits

            if target_ids is not None and torch.rand(1).item() < teacher_forcing_ratio:
                next_ids = target_ids[:, t]
            else:
                _, next_ids = step_logits.max(dim=1)
            input_token = self.embedding(next_ids).unsqueeze(1)

        return outputs


class SentenceVAE(nn.Module):
    """Sentence VAE wiring an encoder and a decoder through a sampled latent.

    Args:
        encoder: module returning ``(mean, logvar)`` for ``(input_ids, attention_mask)``.
        decoder: module called as ``decoder(z, target_ids=..., teacher_forcing_ratio=...)``.
        z_dim: latent size. It is stored for reference and not read by the methods here.
    """

    def __init__(self, encoder, decoder, z_dim):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.z_dim = z_dim

    def sample_z(self, mean, logvar):
        """Reparameterisation trick: ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``."""
        sigma = logvar.mul(0.5).exp()
        eps = torch.randn_like(sigma)
        return mean + sigma * eps

    def forward(self, input_ids, attention_mask, target_ids=None, teacher_forcing_ratio=1.0):
        """Encode, sample ``z``, decode; return ``(reconstruction_logits, mean, logvar)``."""
        mu, log_var = self.encoder(input_ids, attention_mask)
        latent = self.sample_z(mu, log_var)
        reconstruction = self.decoder(
            latent,
            target_ids=target_ids,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )
        return reconstruction, mu, log_var

def sample_from_logits(logits, temperature=1.0):
    """Draw one index per row from the temperature-scaled softmax of ``logits``.

    A 1-D input yields a 0-d tensor; a ``(batch, vocab)`` input yields ``(batch,)``.
    """
    scaled = torch.div(logits, temperature)
    probs = F.softmax(scaled, dim=-1)
    drawn = torch.multinomial(probs, num_samples=1)
    return drawn.squeeze(dim=-1)


# Build the tokenizer, the pretrained encoder backbone and the (untrained) VAE.
PRETRAINED_NAME = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(PRETRAINED_NAME)
distilbert_encoder = DistilBertModel.from_pretrained(PRETRAINED_NAME)
# Width of DistilBERT's token representations (768 for distilbert-base).
# Note: DistilBertConfig.hidden_dim is the feed-forward size (3072), not this.
hidden_dim = distilbert_encoder.config.dim
z_dim = 16
vocab_size = tokenizer.vocab_size

encoder = Encoder(distilbert_encoder, hidden_dim, z_dim)
decoder = Decoder(z_dim, hidden_dim, vocab_size)
# Reuse the shared constant so the VAE's recorded z_dim cannot drift from the
# size the encoder/decoder heads were actually built with.
model = SentenceVAE(encoder, decoder, z_dim=z_dim)

In [None]:
def load_state_dict(model, filepath, map_location='cpu', strict=True):
    """Load a checkpoint into ``model``, tolerating a DataParallel/DDP ``module.`` prefix.

    Args:
        model: the ``nn.Module`` to populate.
        filepath: path to a file produced by ``torch.save(model.state_dict(), ...)``.
        map_location: forwarded to ``torch.load`` (defaults to CPU, as before).
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        The ``(missing_keys, unexpected_keys)`` result of ``model.load_state_dict``.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(filepath, map_location=map_location)
    prefix = "module."
    # Strip only a *leading* prefix. str.replace would also delete "module."
    # occurring inside a key (e.g. a submodule literally named ``module``).
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    return model.load_state_dict(cleaned, strict=strict)
 
CHECKPOINT_PATH = "best_sentence_model.pt"

load_state_dict(model, CHECKPOINT_PATH)
model.eval()
# `tokenizer` is the instance created together with the model above; loading
# the identical tokenizer a second time is unnecessary.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def _sample_top_k(logits, top_k):
    """Draw one token id from ``(1, vocab)`` logits restricted to the ``top_k`` best."""
    vocab = logits.size(-1)
    # None or a non-positive value disables the filter (full-vocabulary sampling).
    k = vocab if top_k is None or top_k <= 0 else min(top_k, vocab)
    top_values, top_indices = torch.topk(logits, k)
    choice = torch.multinomial(F.softmax(top_values, dim=-1), 1).item()
    return top_indices[0, choice].item()


def _decode_latent(model, z, max_length, temperature, top_k, device):
    """Autoregressively decode one latent ``z`` of shape ``(1, z_dim)`` into token ids.

    Starts from [CLS], stops after emitting [SEP] or after ``max_length`` steps.
    Uses the module-level ``tokenizer``.
    """
    # NOTE(review): Decoder.forward feeds an all-zero vector at step 0 and only
    # sees token embeddings from step 1 on. Starting here from the [CLS]
    # embedding with a fresh hidden state may not match how the model was
    # trained; verify against the training loop.
    generated_ids = [tokenizer.cls_token_id]
    input_token = model.decoder.embedding(torch.tensor([[tokenizer.cls_token_id]], device=device))
    h = torch.tanh(model.decoder.fc(z)).unsqueeze(0)
    for _ in range(max_length):
        output, h = model.decoder.gru(input_token, h)
        logits = model.decoder.output_layer(output.squeeze(1)) / temperature
        next_token_id = _sample_top_k(logits, top_k)
        generated_ids.append(next_token_id)
        if next_token_id == tokenizer.sep_token_id:
            break
        input_token = model.decoder.embedding(torch.tensor([[next_token_id]], device=device))
    return generated_ids


def generate_variations(model, input_text, max_length=256, temperature=0.7, top_k=50, num_variations=4, perturb_scale=0.1):
    """Sample ``num_variations`` decoded sentences around the encoding of ``input_text``.

    The text is encoded once to ``(mean, logvar)``. Each variation draws
    ``z = model.sample_z(mean, logvar)`` plus extra Gaussian noise with std
    ``perturb_scale``, then decodes it with temperature-scaled top-k sampling.

    Args:
        model: module exposing ``encoder``, ``decoder`` (``fc``, ``gru``,
            ``embedding``, ``output_layer``) and ``sample_z``.
        input_text: source text.
        max_length: encoder input length (padded/truncated) and also the
            maximum number of decoding steps.
        temperature: logits are divided by this before sampling; must be > 0.
        top_k: sample among the ``top_k`` most likely tokens. None or a value
            <= 0 samples from the full vocabulary.
        num_variations: number of sentences to generate.
        perturb_scale: std of the additional noise added to each sampled ``z``.

    Returns:
        List of decoded strings with special tokens removed.

    Raises:
        ValueError: if ``temperature`` is not positive.

    Uses the module-level ``tokenizer``.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    # Put inputs wherever the model's weights live instead of relying on a global.
    device = next(model.parameters()).device

    encoding = tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    model.eval()
    variations = []
    with torch.no_grad():
        mean, logvar = model.encoder(input_ids, attention_mask)
        for _ in range(num_variations):
            # Noise is drawn before sample_z's own epsilon (RNG order preserved).
            noise = torch.randn_like(mean) * perturb_scale
            z = model.sample_z(mean, logvar) + noise
            generated_ids = _decode_latent(model, z, max_length, temperature, top_k, device)
            variations.append(tokenizer.decode(generated_ids, skip_special_tokens=True))

    return variations

# Generation is stochastic; fix the RNG so the printed variations are
# reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

input_text = "This show is so good, I want to watch it again. I love it! The plot is amazing and the acting is great."
variations = generate_variations(model, input_text, temperature=0.9, top_k=50, num_variations=4, perturb_scale=0.1)
for i, sentence in enumerate(variations, 1):
    print(f"Variation {i}: {sentence}")

Variation 1: i have made the story is all the most comedy, this movie and it's a to make one movie of her.
Variation 2: despite its own, of the.s the's worth all the first ( and the'll be an intelligent, but i can be even for the it, and an, and is the only to look.
Variation 3: it's the film that is so much a very documentary inventive..
Variation 4: a one of the film - -, it's a a lot of the. or a story's life of the worst.
