In [38]:
import torch

print(torch.cuda.is_available())

True


In [39]:
import pandas as pd

df = pd.read_csv("hf://datasets/jstet/quotes-500k/quotes.csv")

In [40]:
df

Unnamed: 0,quote,author,category
0,"I'm selfish, impatient and a little insecure. ...",Marilyn Monroe,"attributed-no-source, best, life, love, mistak..."
1,You've gotta dance like there's nobody watchin...,William W. Purkey,"dance, heaven, hurt, inspirational, life, love..."
2,You know you're in love when you can't fall as...,Dr. Seuss,"attributed-no-source, dreams, love, reality, s..."
3,A friend is someone who knows all about you an...,Elbert Hubbard,"friend, friendship, knowledge, love"
4,Darkness cannot drive out darkness: only light...,"Martin Luther King Jr., A Testament of Hope: T...","darkness, drive-out, hate, inspirational, ligh..."
...,...,...,...
499704,I do believe the most important thing I can do...,John C. Stennis,"Past, Believe, Help"
499705,I'd say I'm a bit antimadridista although I do...,Isco,"Team, Humility, Know"
499706,The future is now.,Nam June Paik,Now
499707,"In all my life and in the future, I will alway...",Norodom Sihamoni,"Life, My Life, Servant"


In [41]:
import swifter

data = df.copy()

# Author should be truncated after first comma using apply
data["author"] = data["author"].swifter.apply(
    lambda x: x.split(",")[0] if isinstance(x, str) else x
)

# Rename category to categories
data = data.rename(columns={"category": "categories"})

data["categories"] = data["categories"].swifter.apply(
    lambda x: ", ".join([tag.strip() for tag in str(x).split(",")])
)

# Remove the "attributed-no-source" category
data["categories"] = data["categories"].swifter.apply(
    lambda x: ", ".join(
        [tag for tag in str(x).split(",") if tag != "attributed-no-source"]
    )
)


Pandas Apply:   0%|          | 0/499709 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/499709 [00:00<?, ?it/s]

Pandas Apply:   0%|          | 0/499709 [00:00<?, ?it/s]

In [42]:
data_test = data.sample(frac=0.2, random_state=42)
data = data.drop(data_test.index)

In [43]:
data

Unnamed: 0,quote,author,categories
0,"I'm selfish, impatient and a little insecure. ...",Marilyn Monroe,"best, life, love, mistakes, out-of-contro..."
1,You've gotta dance like there's nobody watchin...,William W. Purkey,"dance, heaven, hurt, inspirational, life, ..."
3,A friend is someone who knows all about you an...,Elbert Hubbard,"friend, friendship, knowledge, love"
4,Darkness cannot drive out darkness: only light...,Martin Luther King Jr.,"darkness, drive-out, hate, inspirational, ..."
5,We accept the love we think we deserve.,Stephen Chbosky,"inspirational, love"
...,...,...,...
499702,The future isn't just a place you'll go. It's ...,Nancy Duarte,"You, Place, Will"
499703,The Christian of the future will be a mystic o...,Karl Rahner,"Christian, Will, Exist"
499704,I do believe the most important thing I can do...,John C. Stennis,"Past, Believe, Help"
499706,The future is now.,Nam June Paik,Now


In [44]:
data_train = data.sample(frac=0.99, random_state=42)
data_val = data.drop(data_train.index)
data_train = data_train.reset_index(drop=True)
data_val = data_val.reset_index(drop=True)

data_train.head()

Unnamed: 0,quote,author,categories
0,A black telephonereceiver was stuffed in the s...,M.L. Terese,"detective, mystery, p-hone"
1,Bailey took an exasperated breath and sat up i...,Heather McVea,"adolescence, reason, teenager"
2,I'm getting a daily email from Microsoft which...,Steven Magee,"10, access, account, all, along, daily, ..."
3,When you follow your heart you allow miraculou...,Menna van Praag,"destiny, dreams, goals, heart, miracles"
4,You have been complaining so long about your l...,Israelmore Ayivor,"accomplish, accomplishment, achievement, ac..."


In [13]:
def create_input_text(quote, author):
    """Create the input text with descriptive prompt"""
    return f'What tags or categories would best describe this quote: "{quote}" by {author}? Provide comma-separated tags.'

In [46]:
from datasets import Dataset
from tqdm.auto import tqdm

In [47]:
def process_data_in_chunks(df, chunk_size=50000):
    """Process dataframe in chunks, returning a dataset"""

    all_formatted_data = []

    # Get total number of chunks for progress tracking
    num_chunks = (len(df) + chunk_size - 1) // chunk_size

    for i in tqdm(
        range(0, len(df), chunk_size), total=num_chunks, desc="Processing chunks"
    ):
        # Extract chunk
        chunk = df.iloc[i : i + chunk_size].copy()

        # Directly create input_text column (more efficient than apply)
        chunk["input_text"] = [
            create_input_text(quote, author)
            for quote, author in zip(chunk["quote"], chunk["author"])
        ]

        # Use existing categories column
        chunk["target_text"] = chunk["categories"]

        # Select only the columns we need
        formatted_chunk = chunk[["input_text", "target_text"]]

        # Convert to records and add to result
        all_formatted_data.extend(formatted_chunk.to_dict("records"))

    # Create dataset from all processed data
    return Dataset.from_list(all_formatted_data)

In [48]:
# Process data in chunks
train_dataset = process_data_in_chunks(data_train)
val_dataset = process_data_in_chunks(data_val)

# Save datasets to disk
train_dataset.save_to_disk("train_dataset")
val_dataset.save_to_disk("val_dataset")

Processing chunks:   0%|          | 0/8 [00:00<?, ?it/s]

Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Saving the dataset (0/1 shards):   0%|          | 0/395769 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3998 [00:00<?, ? examples/s]

In [49]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [50]:
model_name = "google/flan-t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [51]:
def preprocess_function(examples):
    inputs = tokenizer(
        examples["input_text"], padding="max_length", truncation=True, max_length=256
    )
    outputs = tokenizer(
        examples["target_text"], padding="max_length", truncation=True, max_length=64
    )

    batch = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "labels": outputs.input_ids.copy(),
    }

    # Replace pad token id with -100 so it's ignored in loss calculation
    batch["labels"] = [
        [(label if label != tokenizer.pad_token_id else -100) for label in labels]
        for labels in batch["labels"]
    ]

    return batch

In [52]:
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True, num_proc=16)


Map (num_proc=16):   0%|          | 0/395769 [00:00<?, ? examples/s]

In [53]:
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/3998 [00:00<?, ? examples/s]

In [54]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, jaccard_score

In [55]:
import numpy as np

In [56]:
def compute_metrics(eval_preds, tokenizer):
    """
    Compute metrics using semantic similarity and Jaccard similarity

    Args:
        eval_preds: Tuple containing (predictions, labels)
        tokenizer: The tokenizer to decode predictions and labels

    Returns:
        Dictionary with key metrics
    """
    # Load sentence transformer model (only once)
    global sentence_model
    if "sentence_model" not in globals():
        model_name = "all-MiniLM-L6-v2"
        print(f"Loading sentence transformer model: {model_name}")
        sentence_model = SentenceTransformer(model_name)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sentence_model.to(device)

    # Decode predictions and labels
    preds, labels = eval_preds

    # Clip predictions to vocab size and decode
    preds = np.clip(preds, 0, tokenizer.vocab_size - 1)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 with pad token id in labels and decode
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Clean and process predictions and labels
    processed_preds = []
    processed_labels = []

    for pred, label in zip(decoded_preds, decoded_labels):
        # Simple processing - split by comma and clean
        pred_tags = [tag.strip() for tag in pred.split(",") if tag.strip()]
        label_tags = [tag.strip() for tag in label.split(",") if tag.strip()]

        # Remove duplicates
        pred_tags = list(dict.fromkeys(pred_tags))

        processed_preds.append(", ".join(pred_tags))
        processed_labels.append(", ".join(label_tags))

    # Sample a subset for faster evaluation
    max_examples = min(50, len(processed_preds))
    sample_indices = np.random.choice(len(processed_preds), max_examples, replace=False)

    sample_preds = [processed_preds[i] for i in sample_indices]
    sample_labels = [processed_labels[i] for i in sample_indices]

    # Calculate semantic similarity
    with torch.no_grad():
        pred_embeddings = sentence_model.encode(sample_preds)
        label_embeddings = sentence_model.encode(sample_labels)

    # Calculate cosine similarity for each pair
    similarities = []
    for i in range(len(pred_embeddings)):
        sim = cosine_similarity(
            pred_embeddings[i].reshape(1, -1), label_embeddings[i].reshape(1, -1)
        )[0][0]
        similarities.append(float(sim))

    semantic_similarity = sum(similarities) / len(similarities) if similarities else 0.0

    # Calculate Jaccard similarity
    # Create a set of all unique tags
    all_tags = set()
    for pred, label in zip(sample_preds, sample_labels):
        pred_tags = set(t.strip() for t in pred.split(",") if t.strip())
        label_tags = set(t.strip() for t in label.split(",") if t.strip())
        all_tags.update(pred_tags)
        all_tags.update(label_tags)

    all_tags = list(all_tags)

    # Create multi-hot vectors for F1 and Jaccard calculations
    true_labels = []
    pred_labels = []

    for pred, label in zip(sample_preds, sample_labels):
        pred_tags = set(t.strip() for t in pred.split(",") if t.strip())
        label_tags = set(t.strip() for t in label.split(",") if t.strip())

        true_vec = [1 if tag in label_tags else 0 for tag in all_tags]
        pred_vec = [1 if tag in pred_tags else 0 for tag in all_tags]

        true_labels.append(true_vec)
        pred_labels.append(pred_vec)

    # Calculate Jaccard score
    jaccard = 0.0
    if all_tags and any(true_labels) and any(pred_labels):
        jaccard = jaccard_score(
            y_true=true_labels, y_pred=pred_labels, average="macro", zero_division=0
        )

    # Calculate final score
    score = semantic_similarity + 100.0 * jaccard

    # Print example outputs
    num_examples = min(3, len(sample_preds))
    for i in range(num_examples):
        print(f"\nExample {i}:")
        print(f"Prediction: {sample_preds[i]}")
        print(f"Reference: {sample_labels[i]}")
        print(f"Semantic similarity: {similarities[i]:.4f}")
        # Calculate Jaccard for this specific example
        pred_tags = set(t.strip() for t in sample_preds[i].split(",") if t.strip())
        label_tags = set(t.strip() for t in sample_labels[i].split(",") if t.strip())
        if pred_tags or label_tags:
            example_jaccard = (
                len(pred_tags.intersection(label_tags))
                / len(pred_tags.union(label_tags))
                if pred_tags.union(label_tags)
                else 0
            )
            print(f"Jaccard similarity: {example_jaccard:.4f}")

    return {
        "semantic_similarity": semantic_similarity,
        "jaccard_score": jaccard,
        "score": score,
    }

In [57]:
class TaggingTrainer(Seq2SeqTrainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Enhanced loss function for tag prediction that includes:
        - Basic repetition penalty
        - Malformed tag penalty
        - Jaccard similarity bonus
        """
        # Get standard loss using parent implementation
        outputs = model(**inputs)
        loss = outputs.loss

        # Only apply custom loss during training
        if self.args.do_train:
            # Get the generated logits and batch size
            logits = outputs.logits
            batch_size = logits.shape[0]

            with torch.no_grad():
                # Get predictions
                generated_ids = torch.argmax(logits, dim=-1)
                generated_texts = self.tokenizer.batch_decode(
                    generated_ids, skip_special_tokens=True
                )

                # Get reference texts if available
                labels = inputs.get("labels", None)
                reference_texts = None
                if labels is not None:
                    # Mask out padding tokens
                    labels_masked = labels.clone()
                    labels_masked[labels_masked == -100] = self.tokenizer.pad_token_id
                    reference_texts = self.tokenizer.batch_decode(
                        labels_masked, skip_special_tokens=True
                    )

                # Initialize loss components
                repetition_penalty = 0.0
                malformed_penalty = 0.0
                jaccard_bonus = 0.0

                for idx, text in enumerate(generated_texts):
                    # Split by comma to get individual tags
                    tags = [tag.strip() for tag in text.split(",") if tag.strip()]

                    if not tags:
                        continue

                    # Process reference tags if available
                    ref_tags = []
                    if reference_texts and idx < len(reference_texts):
                        ref_tags = [
                            tag.strip()
                            for tag in reference_texts[idx].split(",")
                            if tag.strip()
                        ]

                    # Calculate repetition penalty
                    unique_tags = set(tags)
                    if len(tags) > 0:
                        repetition_penalty += 1.0 - (len(unique_tags) / len(tags))

                    # Simple malformed tag detection
                    malformed_count = sum(
                        1
                        for tag in tags
                        if (
                            tag.endswith("-quot")
                            or tag.endswith("-s")
                            or tag.endswith("-")
                            or "-quotes-quotes" in tag
                        )
                    )
                    if len(tags) > 0:
                        malformed_penalty += malformed_count / len(tags)

                    # Jaccard similarity calculation when reference is available
                    if ref_tags and tags:
                        pred_set = set(tags)
                        ref_set = set(ref_tags)

                        # Calculate Jaccard similarity (intersection over union)
                        intersection = len(pred_set.intersection(ref_set))
                        union = len(pred_set.union(ref_set))

                        if union > 0:
                            jaccard_score = intersection / union
                            jaccard_bonus += jaccard_score

                # Average metrics across batch
                avg_repetition_penalty = (
                    repetition_penalty / batch_size if batch_size > 0 else 0
                )
                avg_malformed_penalty = (
                    malformed_penalty / batch_size if batch_size > 0 else 0
                )
                avg_jaccard_bonus = jaccard_bonus / batch_size if batch_size > 0 else 0

                # Apply penalties and bonuses to loss with simple weights
                custom_factor = (
                    1.0
                    + 0.5 * avg_repetition_penalty  # Repetition penalty
                    + 0.3 * avg_malformed_penalty  # Malformation penalty
                    - 100
                    * avg_jaccard_bonus  # Jaccard similarity bonus (higher is better)
                )

                # Clip to reasonable range
                custom_factor = max(0.8, min(1.5, custom_factor))

                # Apply factor to base loss
                loss = loss * custom_factor

        return (loss, outputs) if return_outputs else loss

In [58]:
model.gradient_checkpointing_enable()

In [59]:
def compute_metrics_wrapper(eval_preds):
    return compute_metrics(eval_preds, tokenizer)

In [60]:
training_args = Seq2SeqTrainingArguments(
    output_dir="./flan-t5-semantic-tagger",
    eval_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    learning_rate=5e-5,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    generation_max_length=128,
    gradient_accumulation_steps=4,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="score",
    greater_is_better=True,
)

# Update the model's generation config directly
model.config.no_repeat_ngram_size = 2  # Prevents repeating 2-grams
model.config.repetition_penalty = 2.0  # Penalizes token repetition
model.config.diversity_penalty = 0.0  # Diversity penalty is not used in order to trigger beam search
model.config.num_beam_groups = 1  # Diverse beam groups
model.config.num_beams = 5  # Beam search for better quality

In [61]:
trainer = TaggingTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_wrapper,
)

In [62]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss,Semantic Similarity,Jaccard Score,Score
200,2.7498,2.52832,0.371057,0.047619,5.132962
400,2.6823,2.462891,0.447852,0.059561,6.403964
600,2.6537,2.430664,0.426922,0.060658,6.492681
800,2.627,2.411621,0.423445,0.028625,3.285899
1000,2.6264,2.402588,0.422339,0.056985,6.120869
1200,2.6195,2.39624,0.450072,0.049231,5.373173
1400,2.6139,2.39209,0.463997,0.063635,6.827469
1600,2.6035,2.391113,0.463316,0.063107,6.773996
1800,2.6086,2.390137,0.449764,0.064981,6.947891
2000,2.6097,2.390625,0.454804,0.073995,7.854331


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.



Example 0:
Prediction: great-soul, soul-quotes
Reference: amazing, confidence, excellent-life, inspirational, lailah-gifty-akita-affirmations, self-esteem, self-love, self-mootivation, shine-your-light, soul-great-soul, wonderful-people, your-life
Semantic similarity: 0.3310
Jaccard similarity: 0.0000

Example 1:
Prediction: money, money-quotes
Reference: one, day
Semantic similarity: 0.1701
Jaccard similarity: 0.0000

Example 2:
Prediction: confidential-relationships, confidentiality, privacy-quotes
Reference: confidence, relationships
Semantic similarity: 0.4632
Jaccard similarity: 0.0000





Example 0:
Prediction: success, success-quotes
Reference: Make It Happen, You, Happen
Semantic similarity: 0.2621
Jaccard similarity: 0.0000

Example 1:
Prediction: menace, menace-quotes
Reference: male, men, menace
Semantic similarity: 0.6182
Jaccard similarity: 0.2500

Example 2:
Prediction: actor, actors, simian-performance-coach, peter-elliott
Reference: Time, Good, Performance
Semantic similarity: 0.2064
Jaccard similarity: 0.0000

Example 0:
Prediction: poetry, poetry-quotes, kate-tei-yamashita
Reference: characters, identity, karen-tei-yamashita, poems, poet, poetry, poets, self, stories, tropic-of-orange
Semantic similarity: 0.6919
Jaccard similarity: 0.0833

Example 1:
Prediction: love-quotes, christianity
Reference: coming-of-age, edgy-teen-fiction, inspirational-love, new-adult-contemporary-romance, teen-romance, young-adult-romance
Semantic similarity: 0.2267
Jaccard similarity: 0.0000

Example 2:
Prediction: love-quotes, losing-yourself, forgetting-that-you-are-special
Re

There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


TrainOutput(global_step=2319, training_loss=2.6462813982319964, metrics={'train_runtime': 10717.2795, 'train_samples_per_second': 110.784, 'train_steps_per_second': 0.216, 'total_flos': 4.065085989314888e+17, 'train_loss': 2.6462813982319964, 'epoch': 3.0})

In [64]:
trainer.save_model("flan-t5-semantic-tagger-base")

In [65]:
trainer.processing_class.save_pretrained("flan-t5-semantic-tagger-base")

('flan-t5-semantic-tagger-base/tokenizer_config.json',
 'flan-t5-semantic-tagger-base/special_tokens_map.json',
 'flan-t5-semantic-tagger-base/spiece.model',
 'flan-t5-semantic-tagger-base/added_tokens.json')

# Test set preparation

In [44]:
data_test

Unnamed: 0,quote,author,categories
179178,The sting of her abandonment had not lessened ...,T.J. Forrester,"love, relationship, suffering"
183253,Everything that falls upon the eye is an appar...,Marilynn Robinson in Housekeeping,"grief, loss, memory, remembering-the-good, ..."
84139,I don't hate Republicans as individuals. But I...,Howard Dean,"politics, republicans"
272877,Think More Not Less.,Jelani Payne,"critical-thinking, perspective"
195518,"Some individuals have the courage to make it, ...",Gino Segrè,"comfortable, dream, inspirational, journey,..."
...,...,...,...
116792,I've found time can heal most anything and you...,Taylor Swift,"fifteen, music, taylor-swift"
256026,"The whole universe is but a huge Symbol of god"".",Thomas Carlyle,"god, sprituality, universe"
390073,"I am so over you, Rejection. You can't get to ...",Buffy Andrews,"quotes, writers-life, writers-quotes, write..."
366700,I try to live my life every day in the present...,Susan Sarandon,"activism, injustice, life, need"


In [45]:
test_dataset = process_data_in_chunks(data_test)

Processing chunks:   0%|          | 0/2 [00:00<?, ?it/s]

In [47]:
test_dataset.save_to_disk("test_dataset")

Saving the dataset (0/1 shards):   0%|          | 0/99942 [00:00<?, ? examples/s]

# Inference and evaluation

In [80]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [81]:
import torch

In [82]:
from tqdm.auto import tqdm

In [83]:
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

In [8]:
def load_model_and_tokenizer(model_path):
    """
    Load the trained model and tokenizer from the specified path
    """
    print(f"Loading model and tokenizer from {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
    
    # Update the model's generation config for better tag generation
    model.config.no_repeat_ngram_size = 2    # Prevents repeating 2-grams
    model.config.repetition_penalty = 2.0    # Penalizes token repetition
    model.config.diversity_penalty = 0.0     # No diversity penalty to trigger beam search
    model.config.num_beam_groups = 1         # Diverse beam groups
    model.config.num_beams = 5               # Beam search for better quality
    
    return model, tokenizer

In [6]:
def predict_tags(model, tokenizer, text, max_length=128):
    """
    Generate tags for a given text using the trained model
    """
    # Prepare the input
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    input_ids = inputs.input_ids
    
    # Move to the same device as model
    device = model.device
    input_ids = input_ids.to(device)
    
    # Generate tags
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            early_stopping=True,
        )
    
    # Decode the generated tokens
    predicted_tags = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Clean up tags - split by comma and strip whitespace
    tags = [tag.strip() for tag in predicted_tags.split(",") if tag.strip()]
    
    # Remove duplicates while preserving order
    unique_tags = []
    for tag in tags:
        if tag not in unique_tags:
            unique_tags.append(tag)
    
    return unique_tags

In [86]:
def clean_and_deduplicate_tags(tags):
    """
    Clean and deduplicate a list of tags
    """
    # Simple clean up - strip whitespace
    cleaned_tags = [tag.strip() for tag in tags if tag.strip()]
    
    # Remove duplicates while preserving order
    unique_tags = []
    for tag in cleaned_tags:
        if tag not in unique_tags and not any(
            tag.endswith(suffix) for suffix in ["-quot", "-s", "-"]
        ):
            unique_tags.append(tag)
    
    return unique_tags

In [87]:
def evaluate_on_test_set(model, tokenizer, test_dataset, num_examples=None):
    """
    Evaluate the model on a test dataset
    """
    # Switch to evaluation mode
    model.eval()
    device = model.device

    results = []
    similarities = []
    jaccard_scores = []

    # Initialize the sentence transformer for semantic similarity
    global sentence_model
    if "sentence_model" not in globals():
        print("Loading sentence transformer model")
        model_name = "all-MiniLM-L6-v2"
        sentence_model = SentenceTransformer(model_name)
        sentence_model.to(device)

    # Limit examples if specified
    if num_examples is not None and num_examples < len(test_dataset):
        test_dataset = test_dataset.select(range(num_examples))

    print(f"Evaluating on {len(test_dataset)} examples...")

    for idx, example in enumerate(tqdm(test_dataset)):
        text = example["input_text"]
        reference_tags = (
            example["target_text"].split(", ")
            if isinstance(example["target_text"], str)
            else example["target_text"]
        )
        reference_tags = clean_and_deduplicate_tags(reference_tags)

        # Generate predictions
        predicted_tags = predict_tags(model, tokenizer, text)

        # Calculate metrics
        # 1. Jaccard similarity
        pred_set = set(predicted_tags)
        ref_set = set(reference_tags)

        jaccard = 0.0
        if pred_set or ref_set:  # Avoid division by zero
            intersection = len(pred_set.intersection(ref_set))
            union = len(pred_set.union(ref_set))
            jaccard = intersection / union if union > 0 else 0.0

        # 2. Semantic similarity using sentence transformer
        pred_text = ", ".join(predicted_tags)
        ref_text = ", ".join(reference_tags)

        # Encode texts
        with torch.no_grad():
            pred_embedding = sentence_model.encode(pred_text, convert_to_tensor=True)
            ref_embedding = sentence_model.encode(ref_text, convert_to_tensor=True)

            # Calculate cosine similarity
            similarity = cosine_similarity(
                pred_embedding.cpu().numpy().reshape(1, -1),
                ref_embedding.cpu().numpy().reshape(1, -1),
            )[0][0]

        # Store results
        result = {
            "text": text[:200] + "...",  # Truncate for display
            "reference_tags": reference_tags,
            "predicted_tags": predicted_tags,
            "jaccard_score": jaccard,
            "semantic_similarity": similarity,
        }

        results.append(result)
        jaccard_scores.append(jaccard)
        similarities.append(similarity)

        # Print examples (first 3)
        if idx < 3:
            print(f"\nExample {idx}:")
            print(f"Text: {text[:100]}...")
            print(f"Reference tags: {', '.join(reference_tags)}")
            print(f"Predicted tags: {', '.join(predicted_tags)}")
            print(f"Jaccard score: {jaccard:.4f}")
            print(f"Semantic similarity: {similarity:.4f}")

    # Calculate average metrics
    avg_jaccard = sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0
    avg_similarity = sum(similarities) / len(similarities) if similarities else 0

    # Calculate combined score (same as in compute_metrics)
    combined_score = avg_similarity + 2.0 * avg_jaccard

    # Print overall results
    print("\nOverall Results:")
    print(f"Average Jaccard Score: {avg_jaccard:.4f}")
    print(f"Average Semantic Similarity: {avg_similarity:.4f}")
    print(f"Combined Score: {combined_score:.4f}")

    return {
        "results": results,
        "metrics": {
            "jaccard_score": avg_jaccard,
            "semantic_similarity": avg_similarity,
            "combined_score": combined_score,
        },
    }

In [88]:
model_path = "flan-t5-semantic-tagger-base"
model, tokenizer = load_model_and_tokenizer(model_path)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Loading model and tokenizer from flan-t5-semantic-tagger-base


T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

In [89]:
from datasets import load_from_disk

In [90]:
test_dataset = load_from_disk("test_dataset")

In [91]:
# Evaluate on test set
test_results = evaluate_on_test_set(model, tokenizer, test_dataset, num_examples=100)

Evaluating on 100 examples...


  0%|          | 0/100 [00:00<?, ?it/s]


Example 0:
Text: What tags or categories would best describe this quote: "The sting of her abandonment had not lessen...
Reference tags: love, relationship, suffering
Predicted tags: abandonment, t.j.-forrester
Jaccard score: 0.0000
Semantic similarity: 0.2328

Example 1:
Text: What tags or categories would best describe this quote: "Everything that falls upon the eye is an ap...
Reference tags: grief, loss, memory, remembering-the-good, time
Predicted tags: apparition, perishable
Jaccard score: 0.0000
Semantic similarity: 0.2068

Example 2:
Text: What tags or categories would best describe this quote: "I don't hate Republicans as individuals. Bu...
Reference tags: politics, republicans
Predicted tags: America, Hate, Individuals
Jaccard score: 0.0000
Semantic similarity: 0.5124

Overall Results:
Average Jaccard Score: 0.1145
Average Semantic Similarity: 0.4382
Combined Score: 0.6672


In [92]:
import numpy as np
import json


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)

In [93]:
with open("evaluation_results_base.json", "w") as f:
    # Convert results to serializable format
    serializable_results = {
        "metrics": test_results["metrics"],
        "examples": [
            {
                "text": r["text"],
                "reference_tags": r["reference_tags"],
                "predicted_tags": r["predicted_tags"],
                "jaccard_score": r["jaccard_score"],
                "semantic_similarity": r["semantic_similarity"],
            }
            for r in test_results["results"][:10]  # Save first 10 examples
        ],
    }
    json.dump(serializable_results, f, indent=2, cls=NumpyEncoder)

print("Evaluation completed and results saved to evaluation_results.json")

Evaluation completed and results saved to evaluation_results.json


## Exporting to HF

In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [2]:
torch.cuda.is_available()

True

In [3]:
# Load the trained small model and tokenizer
model_path = "flan-t5-semantic-tagger-small"
# model = AutoModelForSeq2SeqLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
# tokenizer = AutoTokenizer.from_pretrained(model_path)

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_path, torch_dtype=torch.float32, local_files_only=True, to_pt=True
)
tokenizer = AutoTokenizer.from_pretrained(
    model_path, use_fast=True, local_files_only=True
)


TypeError: T5ForConditionalGeneration.__init__() got an unexpected keyword argument 'to_pt'

In [4]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
model.push_to_hub("fristrup/flan-t5-semantic-tagger-small")
tokenizer.push_to_hub("fristrup/flan-t5-semantic-tagger-small")

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f71b8826850>>
Traceback (most recent call last):
  File "/mnt/d/repos/github/QuoteWeave/models/.venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


NameError: name 'model' is not defined

## Test inference

In [1]:
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [2]:
inference_model = AutoModelForSeq2SeqLM.from_pretrained(
    "fristrup/flan-t5-semantic-tagger-small", torch_dtype=torch.float32, local_files_only=True
)
inference_tokenizer = AutoTokenizer.from_pretrained(
    "fristrup/flan-t5-semantic-tagger-small", use_fast=True, local_files_only=True
)


You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [3]:
import pandas as pd

In [4]:
df = pd.read_csv("hf://datasets/jstet/quotes-500k/quotes.csv")

In [9]:
inference_model, inference_tokenizer = load_model_and_tokenizer("fristrup/flan-t5-semantic-tagger-small")


Loading model and tokenizer from fristrup/flan-t5-semantic-tagger-small


In [14]:
predict_tags(inference_model, inference_tokenizer, create_input_text(df["quote"].iloc[0], df["author"].iloc[0]))

['selfish', 'impatient', 'insecure']