# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from nltk.tokenize import word_tokenize
import nltk
import re
import os
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm

# Download the NLTK tokenizer data used by word_tokenize. Both downloads are
# quiet so re-running the notebook does not flood the output. 'punkt_tab' is
# presumably needed by newer NLTK releases and 'punkt' by older ones — TODO
# confirm both are still required for the installed NLTK version.
nltk.download('punkt_tab', quiet=True)
nltk.download('punkt', quiet=True)

# Text preprocessing
def preprocess_text(text):
    """Normalise a social-media post before tokenisation.

    The text is lower-cased, URLs starting with ``http`` are removed, every
    ``@mention`` becomes the placeholder ``@user`` and the leading ``#`` of
    hashtags is dropped. Any non-string input (e.g. a NaN from an empty CSV
    cell) produces an empty string.
    """
    if not isinstance(text, str):
        return ""

    # Applied in order; each step sees the output of the previous one.
    substitutions = (
        (r'http\S+', ''),
        (r'@\S+', '@user'),
        (r'#(\S+)', r'\1'),
    )
    cleaned = text.lower()
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned

# Dataset class for text generation
class TextGenerationDataset(Dataset):
    """Map-style dataset of ``(token_ids, one_hot_label)`` pairs for the generator.

    Each text is cleaned with ``preprocess_text``, tokenised with
    ``word_tokenize``, wrapped in ``<START>`` ... ``<END>``, truncated to
    ``max_len`` tokens and right-padded with ``<PAD>`` (id 0).

    Args:
        texts: Indexable sequence of raw texts.
        labels: Indexable sequence of integer class ids, parallel to ``texts``.
        vocab: Optional pre-built ``word -> id`` mapping (e.g. the training
            set's vocabulary). When omitted it is built from ``texts``.
        max_len: Fixed length of every returned token-id tensor.
        num_classes: Length of the returned one-hot label vector.
    """

    # Special tokens always occupy ids 0..3 in a vocabulary built here.
    SPECIAL_TOKENS = ('<PAD>', '<START>', '<END>', '<UNK>')

    def __init__(self, texts, labels, vocab=None, max_len=30, num_classes=4):
        self.texts = texts
        self.labels = labels
        self.max_len = max_len
        self.num_classes = num_classes

        if vocab is None:
            self.build_vocab()
        else:
            self.vocab = vocab
            self.idx_to_word = {v: k for k, v in vocab.items()}
            self.vocab_size = len(vocab)

    def build_vocab(self):
        """Build ``vocab``, ``idx_to_word`` and ``vocab_size`` from ``self.texts``.

        Special tokens get ids 0-3; every other token is numbered in order of
        first occurrence. No frequency cut-off is applied.
        """
        self.vocab = {token: idx for idx, token in enumerate(self.SPECIAL_TOKENS)}
        for text in self.texts:
            for token in word_tokenize(preprocess_text(text)):
                # len() is evaluated before insertion, so new ids are contiguous.
                self.vocab.setdefault(token, len(self.vocab))

        self.idx_to_word = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(LongTensor[max_len] token ids, FloatTensor[num_classes] one-hot)``."""
        tokens = ['<START>'] + word_tokenize(preprocess_text(self.texts[idx])) + ['<END>']
        unk_id = self.vocab['<UNK>']
        # Truncation happens here; texts longer than max_len lose their <END>.
        indices = [self.vocab.get(token, unk_id) for token in tokens[:self.max_len]]
        indices += [self.vocab['<PAD>']] * (self.max_len - len(indices))

        one_hot = torch.zeros(self.num_classes)
        # int() so numpy integer labels index the tensor like plain ints.
        one_hot[int(self.labels[idx])] = 1

        return torch.tensor(indices, dtype=torch.long), one_hot


class TextGenerator(nn.Module):
    """Transformer decoder that generates token sequences conditioned on a class vector.

    Scaled token embeddings, learned positional embeddings and a projected
    class embedding are summed and passed through a stack of
    ``nn.TransformerDecoderLayer`` blocks under a causal mask, then projected
    to vocabulary logits.

    Args:
        vocab_size: Number of output tokens; id 0 is the padding index.
        embedding_dim: Model width (``d_model``) of every decoder layer.
        hidden_dim: Feed-forward width inside each decoder layer.
        latent_dim: Size of the latent vector and of the class embedding.
        output_dim: Size of the conditioning (class) vector.
        max_length: Number of learned positions (upper bound on sequence length).
        num_layers: Number of decoder layers.
        num_heads: Attention heads per layer; must divide ``embedding_dim``.
        dropout: Dropout probability used in the layers and on the input sum.
    """

    # Special token ids; these match the ids TextGenerationDataset assigns.
    PAD_IDX = 0
    START_IDX = 1
    END_IDX = 2

    def __init__(self, vocab_size, embedding_dim, hidden_dim, latent_dim, output_dim, max_length, num_layers=4, num_heads=8, dropout=0.1):
        super(TextGenerator, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.output_dim = output_dim
        self.max_length = max_length
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Word embeddings
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=self.PAD_IDX)

        # Learned positional embeddings; token embeddings are scaled by sqrt(d).
        self.pos_embedding = nn.Embedding(max_length, embedding_dim)
        self.scale = torch.sqrt(torch.tensor(embedding_dim, dtype=torch.float32))

        # Sentiment conditioning
        self.class_embedding = nn.Linear(output_dim, latent_dim)
        # NOTE(review): latent_to_hidden is kept for checkpoint compatibility but
        # is not used by forward(); see the note there.
        self.latent_to_hidden = nn.Linear(latent_dim * 2, hidden_dim)
        self.class_to_embedding = nn.Linear(latent_dim, embedding_dim)

        # One *distinct* layer per depth. Putting the same layer instance into
        # the ModuleList num_layers times would make every layer share weights.
        self.decoder_layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=embedding_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(embedding_dim, vocab_size)

        # Dropout applied to the summed input embeddings.
        self.dropout = nn.Dropout(dropout)

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) bool mask that is True above the diagonal (future positions)."""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
        return mask

    def forward(self, latent_vector, conditioning_vector, input_tokens=None):
        """Return next-token logits for every position of ``input_tokens``.

        Args:
            latent_vector: Tensor of shape (batch, latent_dim). NOTE(review): it
                currently only supplies batch size and device — the latent path
                (``latent_to_hidden``) never reaches the decoder, so it does not
                affect the logits. Wiring it in (e.g. as cross-attention memory)
                would be an architecture change.
            conditioning_vector: Tensor of shape (batch, output_dim), e.g. a one-hot class.
            input_tokens: LongTensor of shape (batch, seq_len); defaults to a single
                ``<START>`` token per row.

        Returns:
            Logits of shape (batch, seq_len, vocab_size).
        """
        batch_size = latent_vector.size(0)
        device = latent_vector.device

        # Class conditioning, broadcast over every position below.
        class_embedding = self.class_embedding(conditioning_vector)  # (batch, latent_dim)
        class_embed = self.class_to_embedding(class_embedding)  # (batch, embedding_dim)

        if input_tokens is None:
            input_tokens = torch.full((batch_size, 1), self.START_IDX, dtype=torch.long, device=device)

        seq_len = input_tokens.size(1)
        positions = torch.arange(0, seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
        token_embeds = self.word_embedding(input_tokens) * self.scale
        pos_embeds = self.pos_embedding(positions)

        sentiment_embeds = class_embed.unsqueeze(1).repeat(1, seq_len, 1)  # (batch, seq_len, embedding_dim)
        decoder_input = self.dropout(token_embeds + pos_embeds + sentiment_embeds)

        tgt_mask = self.generate_square_subsequent_mask(seq_len).to(device)

        # Decoder layers require a memory tensor. An all-zero memory carries no
        # information about the input, so the stack behaves like a decoder-only model.
        memory = torch.zeros_like(decoder_input)

        output = decoder_input
        for layer in self.decoder_layers:
            output = layer(output, memory=memory, tgt_mask=tgt_mask)

        return self.output_projection(output)  # (batch, seq_len, vocab_size)

    def sample(self, latent_vector, conditioning_vector, temperature=1.0, max_length=None):
        """Autoregressively sample token ids, starting from ``<START>``.

        Rows that have emitted ``<END>`` are padded with ``<PAD>`` afterwards;
        sampling stops early once every row has finished. The module's
        train/eval mode is restored on exit, so calling this during training
        does not leave dropout disabled.

        Args:
            latent_vector: Tensor of shape (batch, latent_dim).
            conditioning_vector: Tensor of shape (batch, output_dim).
            temperature: Softmax temperature; values < 1 sharpen the distribution.
            max_length: Maximum output length including ``<START>``;
                defaults to ``self.max_length``.

        Returns:
            LongTensor of shape (batch, L) with ``L <= max_length``.
        """
        if max_length is None:
            max_length = self.max_length
        batch_size = latent_vector.size(0)
        device = latent_vector.device

        was_training = self.training
        self.eval()

        generated = torch.full((batch_size, 1), self.START_IDX, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        try:
            with torch.no_grad():
                for _ in range(max_length - 1):
                    logits = self(latent_vector, conditioning_vector, generated)[:, -1, :]  # (batch, vocab)
                    if temperature != 1.0:
                        logits = logits / temperature

                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
                    # Rows that already produced <END> only receive padding.
                    next_token = next_token.masked_fill(finished.unsqueeze(1), self.PAD_IDX)

                    generated = torch.cat([generated, next_token], dim=1)
                    finished |= next_token.squeeze(1) == self.END_IDX
                    if finished.all():
                        break
        finally:
            self.train(was_training)

        return generated

# Sentiment Classifier (Updated with hooks)
class SentimentClassifier(nn.Module):
    """Bidirectional-LSTM sentence classifier with optional forward-hook feature capture.

    Args:
        vocab_size: Rows in the embedding table; id 0 is the padding index.
        embedding_dim: Token embedding size.
        hidden_dim: LSTM hidden size per direction (the head sees 2 * hidden_dim).
        output_dim: Number of output logits.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout between LSTM layers (only when n_layers > 1) and inside the head.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout):
        super().__init__()
        inter_layer_dropout = dropout if n_layers > 1 else 0
        # Attribute names and construction order define the state_dict keys
        # (embedding.*, lstm.*, fc.0.*, fc.3.*) used by saved checkpoints.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=True,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, text):
        """Map a (batch, seq_len) LongTensor of token ids to (batch, output_dim) logits."""
        _, (h_n, _) = self.lstm(self.embedding(text))
        # The last two entries of h_n are the top layer's forward and backward final states.
        final_state = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(final_state)

    def register_hooks(self, hook_layers):
        """Attach forward hooks that record detached copies of layer outputs.

        Args:
            hook_layers: Iterable of layer ids understood by ``get_layer_by_idx``
                (0 = classifier head, 1 = LSTM). Non-int or unknown ids are ignored.

        Returns:
            ``(handles, captured)``: the hook handles (call ``.remove()`` to detach)
            and one list per hooked layer that each forward pass appends to.
            For the LSTM only the per-step output sequence is recorded.
        """
        def make_recorder(sink, first_element_only):
            def recorder(module, inputs, output):
                value = output[0] if first_element_only else output
                sink.append(value.detach().clone())
            return recorder

        handles, captured = [], []
        for layer_idx in hook_layers:
            if not isinstance(layer_idx, int):
                continue
            target = self.get_layer_by_idx(layer_idx)
            if target is None:
                continue
            sink = []
            captured.append(sink)
            # The LSTM returns (output, (h_n, c_n)); keep only `output`.
            handles.append(target.register_forward_hook(make_recorder(sink, layer_idx == 1)))
        return handles, captured

    def get_layer_by_idx(self, idx):
        """Return the head (0) or the LSTM (1); any other id yields None."""
        if idx == 0:
            return self.fc
        if idx == 1:
            return self.lstm
        return None

# New Loss Functions
def cosine_similarity_loss(features):
    """Mean pairwise cosine similarity between the samples in ``features[0]``.

    ``features`` is a list of tensors captured by a forward hook; only the first
    entry is used. Tensors with more than two dims are averaged over dim 1
    (e.g. the time axis of a batch-first LSTM output) to get one vector per
    sample. Lower values mean the samples' features are more dissimilar.

    Returns:
        A scalar tensor. It is 0 when ``features`` is empty or when the batch
        holds fewer than two samples, because there are then no pairs to compare
        (previously a batch of one produced 0/0 = NaN).
    """
    if not features:
        return torch.tensor(0.0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    first = features[0]
    features_2d = first.mean(dim=1) if first.dim() > 2 else first
    n = features_2d.size(0)
    if n < 2:
        return features_2d.new_zeros(())

    normalized = F.normalize(features_2d, p=2, dim=1)
    similarity = torch.mm(normalized, normalized.t())
    # Exclude self-similarity (always 1) from the average.
    diagonal = torch.eye(n, dtype=torch.bool, device=similarity.device)
    similarity = similarity.masked_fill(diagonal, 0)
    return similarity.sum() / (n * (n - 1))

def feature_orthogonality_loss(features):
    """Penalise deviation of the samples' Gram matrix from the identity.

    Only ``features[0]`` is used; tensors with more than two dims are averaged
    over dim 1 first. The loss is the mean squared difference between
    ``X @ X.T`` and ``I``, further divided by ``rows * cols`` of ``X``.
    An empty ``features`` list gives a zero scalar.
    """
    if not features:
        return torch.tensor(0.0, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    first = features[0]
    flat = first if len(first.shape) <= 2 else first.mean(dim=1)

    gram = torch.mm(flat, flat.t())
    rows, cols = flat.size(0), flat.size(1)
    deviation = gram - torch.eye(rows, device=gram.device)
    return (deviation ** 2).mean() / (rows * cols)

# Updated Training Function
def train_generator(model, iterator, optimizer, criterion, classifier, device, lambda_sent=1.0, lambda_cos=0.5, lambda_orth=0.01, features=None):
    """Train ``model`` for one epoch with teacher forcing.

    For every batch, the generator predicts each next token of the real text
    (``token_loss``). It also samples a fresh sequence per example and scores it
    in two ways. The classifier gives ``sent_loss``, the cross-entropy against
    the requested class. The hook-captured classifier features give ``cos_loss``
    and ``orth_loss``.

    NOTE(review): ``model.sample`` returns discrete token ids (multinomial
    sampling), so sent/cos/orth losses have no gradient path back to the
    generator; only ``token_loss`` trains it. They are therefore computed under
    ``torch.no_grad()`` with the classifier in eval mode. This leaves the
    generator update unchanged while no longer accumulating gradients in the
    classifier's parameters. Making these terms effective would require e.g.
    Gumbel-softmax or a policy-gradient estimator.

    Args:
        model: Generator exposing ``latent_dim``, ``vocab_size``,
            ``forward(latent, cond, tokens)`` and ``sample(latent, cond)``.
        iterator: Iterable of ``(token_ids, one_hot_labels)`` batches with ``len()``.
        optimizer: Optimizer stepped after every batch.
        criterion: Token-level loss applied to flattened logits/targets.
        classifier: Sentiment classifier used to score sampled sequences.
        device: Device the batch tensors are moved to.
        lambda_sent: Weight of ``sent_loss`` in the total loss.
        lambda_cos: Weight of ``cos_loss`` in the total loss.
        lambda_orth: Currently unused; ``orth_loss`` is only logged.
        features: Lists filled by the classifier's forward hooks, or ``None``.

    Returns:
        Per-batch means ``(token_loss, sent_loss, cos_loss, orth_loss)`` as floats
        (all 0.0 for an empty iterator).
    """
    features = [] if features is None else features
    n_batches = len(iterator)
    epoch_loss = epoch_sent_loss = epoch_cos_loss = epoch_orth_loss = 0.0
    classifier.eval()

    for texts, labels in iterator:
        texts, labels = texts.to(device), labels.to(device)
        batch_size = texts.size(0)
        latent_vector = torch.randn(batch_size, model.latent_dim, device=device)

        # sample() below may switch the model to eval mode, so re-enable
        # training mode (dropout) before every teacher-forced pass.
        model.train()

        # Teacher forcing: predict token t+1 from tokens <= t.
        input_tokens = texts[:, :-1]
        target_tokens = texts[:, 1:]
        logits = model(latent_vector, labels, input_tokens)  # (batch, seq_len-1, vocab)
        token_loss = criterion(logits.reshape(-1, model.vocab_size), target_tokens.reshape(-1))

        # Drop hook captures from outside this batch so the feature losses
        # only see the classifier pass below.
        for feat in features:
            feat.clear()

        with torch.no_grad():
            generated_tokens = model.sample(latent_vector, labels)
            sentiment_preds = classifier(generated_tokens)
            sent_loss = F.cross_entropy(sentiment_preds, torch.argmax(labels, dim=1))
            # A tensor start value keeps .item() valid when no features were captured.
            zero = torch.zeros((), device=device)
            cos_loss = sum((cosine_similarity_loss(feat) for feat in features if feat), zero)
            orth_loss = sum((feature_orthogonality_loss(feat) for feat in features if feat), zero)

        total_loss = token_loss + lambda_sent * sent_loss + lambda_cos * cos_loss
        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_loss += token_loss.item()
        epoch_sent_loss += sent_loss.item()
        epoch_cos_loss += cos_loss.item()
        epoch_orth_loss += orth_loss.item()

        for feat in features:
            feat.clear()

    if n_batches == 0:
        return 0.0, 0.0, 0.0, 0.0
    return (epoch_loss / n_batches, epoch_sent_loss / n_batches,
            epoch_cos_loss / n_batches, epoch_orth_loss / n_batches)

# Updated Evaluation Function
def evaluate_generator(model, iterator, criterion, classifier, device, features=None):
    """Evaluate ``model`` on ``iterator`` without updating any weights.

    For each batch, two things are measured. The first is the teacher-forced
    token loss on the real texts. The second is whether the classifier assigns
    a freshly sampled sequence to the class it was conditioned on.

    Args:
        model: Generator exposing ``latent_dim``, ``vocab_size``,
            ``forward(latent, cond, tokens)`` and ``sample(latent, cond)``.
        iterator: Iterable of ``(token_ids, one_hot_labels)`` batches with ``len()``.
        criterion: Token-level loss applied to flattened logits/targets.
        classifier: Sentiment classifier used to score sampled sequences.
        device: Device the batch tensors are moved to.
        features: Hook-capture lists to clear after each batch, or ``None``.

    Returns:
        ``(mean token loss per batch, sentiment accuracy of sampled sequences)``;
        both 0.0 for an empty iterator.
    """
    features = [] if features is None else features
    n_batches = len(iterator)
    epoch_loss, total_correct, total_samples = 0.0, 0, 0
    model.eval()
    classifier.eval()

    with torch.no_grad():
        for texts, labels in iterator:
            texts, labels = texts.to(device), labels.to(device)
            batch_size = texts.size(0)
            latent_vector = torch.randn(batch_size, model.latent_dim, device=device)

            # Teacher forcing: predict token t+1 from tokens <= t.
            logits = model(latent_vector, labels, texts[:, :-1])
            loss = criterion(logits.reshape(-1, model.vocab_size), texts[:, 1:].reshape(-1))

            # Does the classifier recognise the requested sentiment in a sample?
            generated_tokens = model.sample(latent_vector, labels)
            pred_labels = torch.argmax(classifier(generated_tokens), dim=1)
            true_labels = torch.argmax(labels, dim=1)

            epoch_loss += loss.item()
            total_correct += (pred_labels == true_labels).sum().item()
            total_samples += batch_size

            # The classifier's hooks recorded features we don't use here.
            for feat in features:
                feat.clear()

    mean_loss = epoch_loss / n_batches if n_batches else 0.0
    accuracy = total_correct / total_samples if total_samples else 0.0
    return mean_loss, accuracy

# Updated Generate Text Samples
def generate_text_samples(model, dataset, n_samples=10, temperature=1.0, device=None, epoch=0):
    """Sample texts for each class in round-robin order, print them, and return them.

    Sample ``i`` is conditioned on class ``i % model.output_dim``. Each sequence is
    decoded with ``dataset.idx_to_word`` up to its first ``<END>``, skipping
    ``<PAD>``/``<START>``; ids missing from the mapping become ``<UNK>``.

    Args:
        model: Generator exposing ``latent_dim``, ``output_dim`` and
            ``sample(latent, cond, temperature=...)``.
        dataset: Object with an ``idx_to_word`` dict (id -> token).
        n_samples: Number of sequences to generate.
        temperature: Sampling temperature forwarded to ``model.sample``.
        device: Device for the latent/conditioning tensors; CUDA if available by default.
        epoch: Unused; kept for call-site compatibility.

    Returns:
        List of ``n_samples`` space-joined strings.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    sentiment_names = {0: 'Negative', 1: 'Neutral', 2: 'Positive', 3: 'Irrelevant'}
    skipped_tokens = {'<PAD>', '<START>'}

    latent_vectors = torch.randn(n_samples, model.latent_dim).to(device)
    class_ids = torch.arange(n_samples) % model.output_dim
    sentiment_labels = F.one_hot(class_ids, num_classes=model.output_dim).float().to(device)

    token_indices = model.sample(latent_vectors, sentiment_labels, temperature=temperature)

    generated_texts, unique_ratios = [], []
    for row in token_indices.tolist():
        words = []
        for idx in row:
            word = dataset.idx_to_word.get(idx)
            if word == '<END>':
                break  # anything after <END> is not part of the sentence
            if word in skipped_tokens:
                continue
            words.append(word if word is not None else "<UNK>")
        # Track the ratio per sample; a single shared variable would report
        # the last sample's value for every sample.
        unique_ratios.append(len(set(words)) / len(words) if words else 0)
        generated_texts.append(" ".join(words))

    print("\nGenerated Text Samples:")
    for i, (text, ratio) in enumerate(zip(generated_texts, unique_ratios)):
        class_idx = i % model.output_dim
        sentiment = sentiment_names.get(class_idx, f"Class {class_idx}")
        print(f"\nSample {i+1} ({sentiment}):")
        print(text)
        print(f"Unique token ratio: {ratio:.3f}")

    return generated_texts

# Load Dataset
def load_dataset(file_path):
    """Load a headerless sentiment CSV and add a numeric ``target`` column.

    Columns are named by position ``ID, Platform, Sentiment, Text``; any extra
    columns become ``Extra_0``, ``Extra_1``, ... . Sentiment strings map to
    Negative=0, Neutral=1, Positive=2, Irrelevant=3. Unrecognised labels are
    reported and defaulted to Neutral (1).

    Returns:
        The DataFrame, or ``None`` when the file has fewer than four columns or
        anything fails while reading/processing it (the error is printed).
    """
    try:
        df = pd.read_csv(file_path, header=None)
        n_cols = len(df.columns)

        if n_cols < 4:
            print(f"Warning: File {file_path} has only {n_cols} columns. Expected at least 4.")
            print("Error: Text column is missing from the dataset.")
            return None

        base_names = ['ID', 'Platform', 'Sentiment', 'Text']
        df.columns = base_names + [f'Extra_{i}' for i in range(n_cols - len(base_names))]

        label_to_id = {'Positive': 2, 'Neutral': 1, 'Negative': 0, 'Irrelevant': 3}
        df['target'] = df['Sentiment'].map(label_to_id)

        unmapped = df['target'].isna()
        if unmapped.any():
            print(f"Warning: Found unmapped sentiment labels: {df.loc[unmapped, 'Sentiment'].unique()}")
            df.loc[unmapped, 'target'] = 1  # fall back to 'Neutral'

        return df
    except Exception as e:
        print(f"Error loading dataset from {file_path}: {e}")
        return None

# Main Function
def _write_generated_texts(path, header, texts, sentiment_names, num_classes):
    """Write ``header`` then each sample, labelled with class ``i % num_classes``, to a UTF-8 file."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(header)
        for i, text in enumerate(texts):
            sentiment = sentiment_names[i % num_classes]
            f.write(f"Sample {i+1} ({sentiment}):\n{text}\n\n")


def _plot_training_curves(path, train_losses, val_losses, sentiment_accs, cos_losses, orth_losses):
    """Save a three-panel figure (generator loss, sentiment accuracy, feature losses) to ``path``."""
    plt.figure(figsize=(12, 10))

    plt.subplot(3, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(3, 1, 2)
    plt.plot(sentiment_accs, label='Sentiment Accuracy')
    plt.title('Sentiment Classification Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(3, 1, 3)
    plt.plot(cos_losses, label='Cosine Similarity Loss')
    plt.plot(orth_losses, label='Orthogonality Loss')
    plt.title('Feature Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.tight_layout()
    plt.savefig(path)


def main():
    """Train a sentiment-conditioned text generator on the Twitter sentiment data.

    The pipeline runs in five steps:

    1. Load the train/validation CSVs.
    2. Build the vocabulary and save it to ``vocab.json``.
    3. Load a pre-trained ``SentimentClassifier``. Its parameters are not
       given to the optimizer; it only scores generated samples.
    4. Train ``TextGenerator`` with early stopping on validation loss.
    5. Write checkpoints, generated samples and a training-curve plot to
       ``generated_texts/``.
    """
    import json

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    train_file = '/kaggle/input/twitter-dataset/twitter_training.csv'
    val_file = '/kaggle/input/twitter-dataset/twitter_validation.csv'
    classifier_file = '/kaggle/input/social_media_sentiment_classifier/pytorch/default/1/social_media_sentiment_model.pt'
    vocab_file = 'vocab.json'
    output_dir = "generated_texts"
    sentiment_names = {0: 'Negative', 1: 'Neutral', 2: 'Positive', 3: 'Irrelevant'}

    # Shared hyperparameters (used by both the classifier and the generator).
    max_length = 50
    batch_size = 128
    embedding_dim = 100
    hidden_dim = 256
    latent_dim = 512
    output_dim = 4
    # Classifier-only settings; must match the pre-trained checkpoint.
    n_layers = 2
    dropout = 0.5
    pretrained_vocab_size = 40338
    # Training schedule.
    n_epochs = 30
    patience = 8

    for file_path in [train_file, val_file, classifier_file]:
        if not os.path.exists(file_path):
            print(f"File {file_path} not found!")
            return

    train_df = load_dataset(train_file)
    val_df = load_dataset(val_file)
    if train_df is None or val_df is None:
        print("Error loading datasets. Exiting.")
        return

    print(f"Training dataset: {len(train_df)} rows")
    print(f"Validation dataset: {len(val_df)} rows")

    print("\nClass distribution in training set:")
    for idx, count in train_df['target'].value_counts().items():
        print(f"{sentiment_names[int(idx)]}: {count} ({count/len(train_df)*100:.2f}%)")

    X_train, y_train = train_df['Text'].values, train_df['target'].values.astype(int)
    X_val, y_val = val_df['Text'].values, val_df['target'].values.astype(int)

    train_dataset = TextGenerationDataset(X_train, y_train, max_len=max_length)
    vocab = train_dataset.vocab
    vocab_size = train_dataset.vocab_size
    print(f"Vocabulary size: {vocab_size}")
    val_dataset = TextGenerationDataset(X_val, y_val, vocab=vocab, max_len=max_length)

    with open(vocab_file, 'w') as f:
        json.dump(vocab, f)
    print(f"Saved vocabulary to {vocab_file}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    classifier = SentimentClassifier(pretrained_vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, dropout).to(device)
    # strict=False tolerates key mismatches; report them instead of hiding them.
    load_result = classifier.load_state_dict(torch.load(classifier_file, map_location=device), strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"Warning: classifier checkpoint mismatch. Missing keys: {load_result.missing_keys}; "
              f"unexpected keys: {load_result.unexpected_keys}")
    # NOTE(review): the classifier receives token ids from *this* script's
    # vocabulary, which is only meaningful if its checkpoint was trained with an
    # identical vocabulary. Ids >= pretrained_vocab_size index past its embedding table.
    if vocab_size > pretrained_vocab_size:
        print(f"Warning: generator vocabulary ({vocab_size}) is larger than the classifier's "
              f"embedding table ({pretrained_vocab_size}); out-of-range token ids will fail in the classifier.")

    # Sanity-check the classifier on real validation data before using it as a judge.
    classifier.eval()
    val_texts, val_labels = [], []
    for texts, labels in val_loader:
        val_texts.append(texts)
        val_labels.append(torch.argmax(labels, dim=1))
    val_texts = torch.cat(val_texts).to(device)
    val_labels = torch.cat(val_labels).to(device)
    with torch.no_grad():
        pred_labels = torch.argmax(classifier(val_texts), dim=1)
        classifier_acc = (pred_labels == val_labels).float().mean().item()
    print(f"Classifier accuracy on real validation data: {classifier_acc:.3f}")

    generator = TextGenerator(
        vocab_size,
        embedding_dim=embedding_dim,
        hidden_dim=hidden_dim,
        latent_dim=latent_dim,
        output_dim=output_dim,
        max_length=max_length,
        num_layers=2,
        num_heads=4,
        dropout=0.1,
    ).to(device)
    optimizer = optim.Adam(generator.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 is <PAD>

    os.makedirs(output_dir, exist_ok=True)
    best_path = f'{output_dir}/generator_best.pt'

    best_valid_loss = float('inf')
    patience_counter = 0
    train_losses, val_losses, sent_losses, cos_losses, orth_losses, sentiment_accs = [], [], [], [], [], []

    hooks, features = classifier.register_hooks([0, 1])
    try:
        print("\nStarting training...")
        for epoch in range(n_epochs):
            train_loss, train_sent_loss, train_cos_loss, train_orth_loss = train_generator(
                generator, train_loader, optimizer, criterion, classifier, device,
                lambda_sent=0.5, lambda_cos=0.5, lambda_orth=0.5, features=features)
            val_loss, sentiment_acc = evaluate_generator(generator, val_loader, criterion, classifier, device, features)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            sent_losses.append(train_sent_loss)
            cos_losses.append(train_cos_loss)
            orth_losses.append(train_orth_loss)
            sentiment_accs.append(sentiment_acc)

            print(f'Epoch: {epoch+1}/{n_epochs}')
            print(f'Train Loss: {train_loss:.3f}, Sent Loss: {train_sent_loss:.3f}, Cos Loss: {train_cos_loss:.3f}, Orth Loss: {train_orth_loss:.3f}')
            print(f'Val Loss: {val_loss:.3f}')
            print(f'Sentiment Accuracy: {sentiment_acc:.3f}')

            if val_loss < best_valid_loss:
                best_valid_loss = val_loss
                patience_counter = 0
                torch.save(generator.state_dict(), best_path)
                print("Saved best model checkpoint.")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    generator.load_state_dict(torch.load(best_path, map_location=device))
                    break

            if (epoch + 1) % 5 == 0 or epoch == n_epochs - 1:
                generated_texts = generate_text_samples(generator, train_dataset, n_samples=8, temperature=1.0, device=device, epoch=epoch)
                _write_generated_texts(f"{output_dir}/generated_texts_epoch_{epoch+1}.txt",
                                       f"Epoch {epoch+1} Generated Texts:\n\n",
                                       generated_texts, sentiment_names, output_dim)
                torch.save(generator.state_dict(), f'{output_dir}/generator_epoch_{epoch+1}.pt')

        _plot_training_curves(f"{output_dir}/training_progress.png",
                              train_losses, val_losses, sentiment_accs, cos_losses, orth_losses)

        print("\nTraining completed.")
        print(f"Best validation loss: {best_valid_loss:.3f}")
        generator.load_state_dict(torch.load(best_path, map_location=device))
        final_texts = generate_text_samples(generator, train_dataset, n_samples=16, temperature=1.0, device=device)
        _write_generated_texts(f"{output_dir}/final_generated_texts.txt", "Final Generated Texts:\n\n",
                               final_texts, sentiment_names, output_dim)
    finally:
        # Detach the feature-capture hooks even if training fails part-way.
        for hook in hooks:
            hook.remove()

    print(f"\nText generation complete. Results saved to {output_dir}.")

# Entry point: run the full training/generation pipeline when this file (or
# notebook cell) is executed directly rather than imported.
if __name__ == "__main__":
    main()