In [1]:
!pip install -q transformers datasets torch tqdm

In [2]:
!pip install -U transformers

Collecting transformers
  Using cached transformers-4.57.0-py3-none-any.whl.metadata (41 kB)
Using cached transformers-4.57.0-py3-none-any.whl (12.0 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.56.2
    Uninstalling transformers-4.56.2:
      Successfully uninstalled transformers-4.56.2
Successfully installed transformers-4.57.0


In [3]:
# Persona extractor tailored for Personachat (Persona + chat) and ESConv (dialog -> content)

import os, json, torch
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
)
from datasets import Dataset
from tqdm import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

persona_file = "/content/personachat.json"   # your PersonaChat (you provided sample)
esconv_file  = "/content/ESConv.json"       # ESConv (you provided sample)

Using device: cpu


In [4]:
# Load PersonaChat -> history -> persona

def load_personachat_history2persona(path, max_examples=None):
    data = json.load(open(path, "r", encoding="utf-8"))
    examples = []
    for item in data:
        # Persona text: you have "Persona" key in your sample
        persona_text = ""
        # common keys in your sample: "Persona"
        for k in ["Persona", "persona", "personality"]:
            if k in item and item[k]:
                val = item[k]
                if isinstance(val, list):
                    persona_text = " | ".join([v.strip() for v in val if isinstance(v, str) and v.strip()])
                else:
                    persona_text = str(val).strip()
                break

        # chat: your sample has "chat" as newline-separated string
        chat = item.get("chat") or item.get("dialog") or item.get("dialogue") or item.get("utterances") or ""
        turns = []
        if isinstance(chat, str):
            turns = [t.strip() for t in chat.split("\n") if t.strip()]
        elif isinstance(chat, list):
            for u in chat:
                if isinstance(u, str):
                    turns.append(u.strip())
                elif isinstance(u, dict):
                    txt = u.get("content") or u.get("text") or u.get("utterance") or ""
                    if txt:
                        turns.append(str(txt).strip())

        if not persona_text or not turns:
            continue

        src = "history: " + " </s> ".join(turns)
        tgt = persona_text
        examples.append({"source": src, "target": tgt})
        if max_examples and len(examples) >= max_examples:
            break

    print(f"[INFO] Loaded {len(examples)} history->persona examples from {path}")
    return examples

In [5]:
# Quick check on PersonaChat loader
persona_examples = load_personachat_history2persona(persona_file, max_examples=20000)
if len(persona_examples) == 0:
    raise SystemExit("ERROR: No history->persona examples found. Check your personachat.json keys (Persona/chat).")

print("\n=== SAMPLE TRAINING PAIRS (first 5) ===")
for i, ex in enumerate(persona_examples[:5]):
    print(f"\n[{i+1}] SOURCE (trunc): {ex['source'][:200]}")
    print(f"[{i+1}] TARGET (persona): {ex['target'][:200]}")

# Trim for quick runs (optional)
persona_examples = persona_examples[:15000]
n = len(persona_examples)
train_ds = Dataset.from_list(persona_examples[:int(0.9*n)])
val_ds   = Dataset.from_list(persona_examples[int(0.9*n):])

FileNotFoundError: [Errno 2] No such file or directory: '/content/personachat.json'

In [6]:
# Model & Tokenization

model_name = "facebook/bart-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

def tokenize_fn(batch):
    model_inputs = tokenizer(batch["source"], truncation=True, max_length=512)
    # correct seq2seq target tokenization
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch["target"], truncation=True, max_length=128)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Inspect tokenized labels for a small sample to detect -100 or empty targets
sample_check = Dataset.from_list(persona_examples[:4])
sample_tok = sample_check.map(tokenize_fn, batched=True, remove_columns=sample_check.column_names)
print("\n=== TOKENIZED LABELS SAMPLE ===")
for i, item in enumerate(sample_tok):
    labels = item["labels"]
    print(f"Example {i+1}: label tokens len={len(labels)} preview={labels[:20]}")
    if len(labels) == 0 or all([lbl == -100 for lbl in labels]):
        raise SystemExit("ERROR: tokenized labels appear empty or all -100. That means target tokenization failed or target is empty.")

train_tokenized = train_ds.map(tokenize_fn, batched=True, remove_columns=train_ds.column_names)
val_tokenized   = val_ds.map(tokenize_fn,   batched=True, remove_columns=val_ds.column_names)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]


=== TOKENIZED LABELS SAMPLE ===
Example 1: label tokens len=30 preview=[0, 118, 101, 7, 21054, 523, 1611, 4, 939, 101, 7, 213, 8217, 4, 939, 101, 7, 4511, 10, 7323]
Example 2: label tokens len=29 preview=[0, 4783, 3795, 16, 127, 275, 1441, 4, 939, 33, 237, 7502, 4, 939, 679, 14, 9374, 16355, 29, 32]
Example 3: label tokens len=46 preview=[0, 118, 56, 10, 10196, 23, 400, 7364, 94, 363, 4, 939, 173, 25, 10, 1413, 62, 10688, 4, 939]
Example 4: label tokens len=25 preview=[0, 118, 524, 182, 7714, 4, 939, 3568, 9872, 4, 939, 33, 6219, 2549, 4, 939, 657, 44821, 21050, 4]




Map:   0%|          | 0/8045 [00:00<?, ? examples/s]

Map:   0%|          | 0/894 [00:00<?, ? examples/s]

In [7]:
# Sanity generation BEFORE training (should produce something but likely not persona-yet)

def generate_from_text(inp_text, max_len=128):
    inputs = tokenizer(inp_text, return_tensors="pt", truncation=True, max_length=512).to(device)
    model.eval()
    with torch.no_grad():
        ids = model.generate(**inputs, max_length=max_len, num_beams=4, early_stopping=True, min_length=1)
    return tokenizer.decode(ids[0], skip_special_tokens=True).strip()

print("\n=== SANITY GENERATE (pre-train) on first source ===")
test_src = persona_examples[0]["source"]
print("SOURCE (trunc):", test_src[:200])
print("MODEL OUTPUT (pre-train):", repr(generate_from_text(test_src)))


=== SANITY GENERATE (pre-train) on first source ===
SOURCE (trunc): history: hi , how are you doing ? i am getting ready to do some cheetah chasing to stay in shape . </s> you must be very fast . hunting is one of my favorite hobbies . </s> i am ! for my hobby i like 
MODEL OUTPUT (pre-train): 'history: hi , how are you doing ? i am getting ready to do some cheetah chasing to stay in shape .'


In [8]:
# Training (light)

args = Seq2SeqTrainingArguments(
    output_dir="./persona_bart",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    logging_steps=200,
    eval_steps=500,
    save_total_limit=1,
    remove_unused_columns=False,
    report_to="none",
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("\n=== START TRAINING (1 epoch quick) ===")
trainer.train()
trainer.save_model("./persona_bart")
tokenizer.save_pretrained("./persona_bart")
print("Training complete.")

print("\n=== SANITY GENERATE (post-train) on first source ===")
print("MODEL OUTPUT (post-train):", repr(generate_from_text(test_src)))

  trainer = Seq2SeqTrainer(



=== START TRAINING (1 epoch quick) ===


Step,Training Loss
200,2.4754
400,2.1248
600,2.0154
800,1.9154
1000,1.8091
1200,1.7839
1400,1.7088
1600,1.678
1800,1.5946
2000,1.5661




Training complete.

=== SANITY GENERATE (post-train) on first source ===
MODEL OUTPUT (post-train): 'i like to remodel homes. my favorite food is meat. i like to hunt. i am in high school.'


In [None]:
# ESConv inference - tailored to your ESConv sample format

def load_esconv(path):
    data = json.load(open(path, "r", encoding="utf-8"))
    if isinstance(data, dict):
        # repository sometimes wraps dialogs in "dialogs" key
        if "dialogs" in data:
            return data["dialogs"]
        for v in data.values():
            if isinstance(v, list):
                return v
        return []
    return data

def extract_texts_from_entry(entry):
    # Your ESConv sample uses key "dialog" which is a list of dicts with "content"
    texts = []
    if isinstance(entry, list):
        for u in entry:
            if isinstance(u, str):
                texts.append(u.strip())
            elif isinstance(u, dict):
                txt = u.get("content") or u.get("text") or u.get("utterance") or ""
                if txt:
                    texts.append(str(txt).strip())
    elif isinstance(entry, dict):
        # check common keys including "dialog"
        for k in ["dialog", "dialogue", "utterances", "conversation", "turns", "history"]:
            if k in entry and entry[k]:
                return extract_texts_from_entry(entry[k])
        # fallback: if entry has speaker/text pairs, try to extract them
        if "content" in entry:
            texts.append(str(entry["content"]).strip())
    return texts

def compose_input(history_texts):
    return "history: " + " </s> ".join(history_texts)

def infer_persona_from_dialogue(dialogue):
    utterances = extract_texts_from_entry(dialogue)
    if not utterances:
        return "", []
    history = [u for u in utterances if u]
    if not history:
        return "", []
    inp = compose_input(history)
    inputs = tokenizer(inp, return_tensors="pt", truncation=True, max_length=512).to(device)
    model.eval()
    with torch.no_grad():
        ids = model.generate(**inputs, max_length=128, num_beams=4, early_stopping=True, min_length=3)
    persona = tokenizer.decode(ids[0], skip_special_tokens=True).strip()

    # persona_list generation (turnwise) - optional/useful
    persona_list = []
    for i in range(3, len(history)):
        inp2 = compose_input(history[:i+1])
        inputs2 = tokenizer(inp2, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            ids2 = model.generate(**inputs2, max_length=128, num_beams=4, early_stopping=True, min_length=3)
        persona_list.append(tokenizer.decode(ids2[0], skip_special_tokens=True).strip())
    return persona, persona_list

esconv_data = load_esconv(esconv_file)
pesconv_output = []
for dlg in tqdm(esconv_data, desc="Generating personas"):
    persona, plist = infer_persona_from_dialogue(dlg)
    newdlg = dlg.copy() if isinstance(dlg, dict) else {"dialog": dlg}
    newdlg["persona"] = persona
    newdlg["persona_list"] = plist
    pesconv_output.append(newdlg)

with open("PESConv.json", "w", encoding="utf-8") as f:
    json.dump(pesconv_output, f, ensure_ascii=False, indent=2)
print("PESConv.json saved!")

Generating personas:   6%|▌         | 74/1300 [12:00<4:19:23, 12.69s/it]

In [None]:
# Print a few samples
print("\n EXAMPLE OUTPUTS ")
for i, d in enumerate(pesconv_output[:3]):
    print(f"\nDialog {i+1}:")
    # show first few utterances
    utterances = extract_texts_from_entry(d)
    print("Utterances (few):", utterances[:4])
    print("Predicted Persona:", repr(d.get("persona","")))
    print("Persona list (few):", d.get("persona_list", [])[:2])

**Fact Memory Class**

In [None]:
import json
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class FactMemory:
    def __init__(self, model_path="./persona_bart", memory_path="user_memory.json"):
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.memory_path = memory_path

        # Load or initialize memory
        if os.path.exists(memory_path):
            self.memory = json.load(open(memory_path, "r", encoding="utf-8"))
        else:
            self.memory = {
                "persona_facts": [],
                "session_history": [],
                "summaries": []
            }

    def save(self):
        with open(self.memory_path, "w", encoding="utf-8") as f:
            json.dump(self.memory, f, ensure_ascii=False, indent=2)

    # Add new persona facts (from extraction)
    def add_persona_facts(self, facts):
        for f in facts:
            if f not in self.memory["persona_facts"]:
                self.memory["persona_facts"].append(f)
        self.save()

    # Record ongoing conversation
    def add_session_turn(self, user, bot):
        self.memory["session_history"].append({"user": user, "bot": bot})
        self.save()

    # Retrieve all memory for dialogue conditioning
    def get_context(self):
        persona = " ".join(self.memory["persona_facts"])
        session = " ".join([f"User: {x['user']} Bot: {x['bot']}" for x in self.memory["session_history"][-5:]])
        return f"Persona: {persona}\nHistory: {session}"

    # Summarize long history (to keep memory concise)
    def summarize_history(self, max_len=10):
        if len(self.memory["session_history"]) > max_len:
            text = " ".join([f"User: {x['user']} Bot: {x['bot']}" for x in self.memory["session_history"]])
            summary = self.summarize_text(text)
            self.memory["summaries"].append(summary)
            self.memory["session_history"] = []
            self.save()
            return summary
        return None

    # Uses model to summarize session history
    def summarize_text(self, text):
        prompt = f"Summarize key user facts and themes: {text}"
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        ids = self.model.generate(**inputs, max_length=100, num_beams=4)
        return self.tokenizer.decode(ids[0], skip_special_tokens=True)


**Integrate Memory with Dialogue Generation**

In [None]:
def generate_response(user_input, fact_memory):
    # Update session history with user input
    fact_memory.add_session_turn(user=user_input, bot="")

    # Prepare context (persona + history)
    context = fact_memory.get_context()
    prompt = f"{context}\nUser: {user_input}\nBot:"

    inputs = fact_memory.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    ids = fact_memory.model.generate(**inputs, max_length=80, num_beams=4)
    reply = fact_memory.tokenizer.decode(ids[0], skip_special_tokens=True)

    # Update memory with generated bot response
    fact_memory.memory["session_history"][-1]["bot"] = reply
    fact_memory.save()
    return reply


**User-Controlled Updates (View, Add, Delete Facts)**

In [None]:
def view_memory(fact_memory):
    print("=== Persona Facts ===")
    for i, f in enumerate(fact_memory.memory["persona_facts"]):
        print(f"{i+1}. {f}")
    print("\n=== Summaries ===")
    for s in fact_memory.memory["summaries"]:
        print("-", s)

def add_fact(fact_memory, new_fact):
    fact_memory.add_persona_facts([new_fact])
    print(f"Added fact: '{new_fact}'")

def delete_fact(fact_memory, index):
    try:
        removed = fact_memory.memory["persona_facts"].pop(index - 1)
        fact_memory.save()
        print(f"Deleted fact: '{removed}'")
    except IndexError:
        print("Invalid index.")
