In [None]:
import json
import os
import random
from collections import defaultdict

# === Configuration ===
INPUT_PATH = "/content/drive/MyDrive/Deep Learning/HANet/train.jsonl"  # ← path to the original (raw) dataset file
OUTPUT_DIR = "/content/drive/MyDrive/Deep Learning/HANet/data"  # ← directory where the generated files are written
BASE_TYPE_COUNT = 50     # number of event types in the base task
INCREMENTAL_TASKS = 5    # number of few-shot incremental tasks
SHOT_PER_TYPE = 5        # sentences kept per event type in each incremental task
SEED = 42                # seeds the event-type shuffle so the split is reproducible

random.seed(SEED)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# One JSON document per line. Read as UTF-8 explicitly and skip blank lines
# (json.loads("") would raise on a stray empty line).
with open(INPUT_PATH, "r", encoding="utf-8") as f:
    documents = [json.loads(line) for line in f if line.strip()]

# Step 1: flatten documents into sentence-level examples.
# Every event mention is attached to the sentence (by "sent_id") it occurs in.
# A sentence is kept only when at least one mention points at it; each mention
# becomes a link {"event_type": ..., "trigger": <mention offset>} (the trigger
# word itself is read but not written out).
flattened = []

for document in documents:
    sentences = document["content"]

    # sentence index -> [(event type, trigger word, offset), ...] in document order
    mentions_by_sentence = defaultdict(list)
    for event in document.get("events", []):
        for mention in event["mention"]:
            sentence_index = mention["sent_id"]
            mentions_by_sentence[sentence_index].append(
                (event["type"], mention["trigger_word"], mention["offset"])
            )

    for sentence_index, sentence in enumerate(sentences):
        tokens = sentence["tokens"]
        links = [
            {"event_type": event_type, "trigger": offset}
            for event_type, _trigger_word, offset in mentions_by_sentence.get(sentence_index, [])
        ]
        if not links:
            continue
        flattened.append({"words": tokens, "gold_evt_links": links})

# Step 2: index sentences by the event types they contain. A sentence is listed
# at most once per type, even when it holds several mentions of that type.
event2examples = defaultdict(list)
for item in flattened:
    for etype in dict.fromkeys(evt["event_type"] for evt in item["gold_evt_links"]):
        event2examples[etype].append(item)

all_event_types = sorted(event2examples.keys())
random.shuffle(all_event_types)

with open(os.path.join(OUTPUT_DIR, "event_types.txt"), "w", encoding="utf-8") as f:
    for etype in all_event_types:
        f.write(etype + "\n")

# Step 3: split the shuffled types into a base task and incremental tasks.
TYPES_PER_TASK = 10  # new event types introduced by each incremental task
base_types = all_event_types[:BASE_TYPE_COUNT]
incre_groups = [
    all_event_types[start:start + TYPES_PER_TASK]
    for start in range(BASE_TYPE_COUNT, BASE_TYPE_COUNT + TYPES_PER_TASK * INCREMENTAL_TASKS, TYPES_PER_TASK)
]
needed_types = BASE_TYPE_COUNT + TYPES_PER_TASK * INCREMENTAL_TASKS
if len(all_event_types) < needed_types:
    print(f"⚠️ Only {len(all_event_types)} event types found but {needed_types} are needed; "
          f"later tasks will be smaller or empty.")


def _restrict_links(item, allowed_types):
    """Return a copy of `item` whose gold_evt_links keep only `allowed_types`.

    Every link becomes a separate training sample downstream, so links of types
    that do not belong to the task being written must be dropped; otherwise
    labels of other tasks leak into it.
    """
    return {
        "words": item["words"],
        "gold_evt_links": [evt for evt in item["gold_evt_links"] if evt["event_type"] in allowed_types],
    }


def _write_jsonl(path, items):
    """Write `items` to `path`, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            f.write(json.dumps(item) + "\n")


# Base task: every sentence that mentions a base type, written exactly once and
# labelled only with its base-type links.
base_type_set = set(base_types)
base_data = [
    restricted
    for restricted in (_restrict_links(item, base_type_set) for item in flattened)
    if restricted["gold_evt_links"]
]
_write_jsonl(os.path.join(OUTPUT_DIR, "base_task.jsonl"), base_data)
print(f"✅ Base task: {len(base_data)} examples from {len(base_types)} types.")

# Incremental tasks: the first SHOT_PER_TYPE distinct sentences of each new
# type, each labelled with that type only.
for i, group in enumerate(incre_groups):
    few_data = []
    for t in group:
        only_t = {t}
        few_data.extend(_restrict_links(item, only_t) for item in event2examples[t][:SHOT_PER_TYPE])
    _write_jsonl(os.path.join(OUTPUT_DIR, f"incremental_task_{i+1}.jsonl"), few_data)
    print(f"✅ Incremental task {i+1}: {len(few_data)} examples from {len(group)} types.")

## Train base

In [None]:
# hanet_training.py with augmentation & replay (HANet-style)
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import json
from collections import defaultdict
import random

# ======= Config =======
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 64
BATCH_SIZE = 16
LR = 2e-5
EPOCHS = 3
REPLAY_PER_CLASS = 1
SIGMA = 0.1  # For prototypical augmentation
SEED = 42    # seeds classifier-head init, DataLoader shuffling and augmentation noise

random.seed(SEED)
torch.manual_seed(SEED)

# ======= Tokenizer =======
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# ======= Label Setup =======
# event_types.txt lists one event type per line, in the shuffled order the
# data-preparation cell wrote; that order fixes the class ids.
with open("/content/drive/MyDrive/Deep Learning/HANet/data/event_types.txt", encoding="utf-8") as f:
    all_event_types = [line.strip() for line in f if line.strip()]  # every available event type (blank lines ignored)
label2id = {etype: i for i, etype in enumerate(all_event_types)}  # event type -> class id
id2label = {i: t for t, i in label2id.items()}  # inverse mapping: class id -> event type


In [None]:
# ======= Dataset =======
class EventDataset(Dataset):
    """Sentence-level event-type classification dataset read from a JSONL file.

    Each non-blank line is a JSON object with ``"words"`` (token list) and an
    optional ``"gold_evt_links"`` list of ``{"event_type": ...}`` dicts. Every
    link becomes its own sample ``(words, event_type)``, so a sentence with k
    links yields k samples and a sentence with none yields no sample.

    Args:
        jsonl_file: path to the JSONL file.
        label2id: mapping from event-type string to integer class id.
        tokenizer: callable used as ``tokenizer(text, padding=..., truncation=...,
            max_length=..., return_tensors='pt')``.
        max_len: sequences are padded/truncated to this many tokens.
    """

    def __init__(self, jsonl_file, label2id, tokenizer, max_len=64):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label2id = label2id
        with open(jsonl_file, encoding="utf-8") as f:
            for line in f:
                if not line.strip():  # tolerate blank lines instead of crashing in json.loads
                    continue
                item = json.loads(line)
                words = item["words"]
                for evt in item.get("gold_evt_links", []):
                    self.samples.append((words, evt["event_type"]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenise sample ``idx``; returns input_ids, attention_mask and a scalar label tensor."""
        words, label = self.samples[idx]
        sent = " ".join(words)
        inputs = self.tokenizer(sent, padding='max_length', truncation=True,
                                max_length=self.max_len, return_tensors='pt')
        return {
            # squeeze(0) drops only the leading batch axis; a bare squeeze() would
            # also collapse the sequence axis when max_len == 1.
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": torch.tensor(self.label2id[label]),
        }


# ======= Model =======
class HANetSimple(nn.Module):
    """BERT encoder followed by a linear classification head on ``pooler_output``.

    ``forward`` returns ``(logits, pooled)`` so callers can reuse the sentence
    representation (e.g. for prototype augmentation).

    Args:
        num_classes: number of output classes of the linear head.
        model_name: Hugging Face checkpoint to load; defaults to the original
            "bert-base-uncased".
    """

    def __init__(self, num_classes, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Size the head from the encoder config instead of hard-coding 768, so
        # checkpoints with a different hidden size work unchanged.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = out.pooler_output
        return self.classifier(pooled), pooled

# ======= Memory Buffer for Replay =======
def build_exemplar_memory(dataset, per_class=1):
    """Collect up to ``per_class`` replay exemplars for every label in ``dataset``.

    Items are visited in iteration order and the first ``per_class`` items of each
    label are kept. Every label encountered gets a key, even when
    ``per_class`` is 0 (its list then stays empty).

    Returns:
        defaultdict(list) mapping the int label (``item["label"].item()``) to
        the list of dataset items kept for it.
    """
    exemplars = defaultdict(list)
    for sample in dataset:
        bucket = exemplars[sample["label"].item()]
        if len(bucket) >= per_class:
            continue
        bucket.append(sample)
    return exemplars

# ======= Prototypical Augmentation =======
def augment_feature(rep, num_aug=5, sigma=0.1):
    """Return a list of ``num_aug`` noisy copies of the tensor ``rep``.

    Each copy is ``rep + randn_like(rep) * sigma``. The Gaussian noise has the
    same shape, dtype and device as ``rep`` and is drawn independently per copy
    from torch's global RNG.
    """
    copies = []
    for _ in range(num_aug):
        noise = torch.randn_like(rep) * sigma
        copies.append(rep + noise)
    return copies

In [None]:
################## evaluate
import torch

def evaluate(model, dataloader, device=None):
    """Return top-1 accuracy of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode (and left there). Each batch must provide
    "input_ids", "attention_mask" and "label". ``model(input_ids, attention_mask)``
    must return ``(logits, something)``, and the prediction is the argmax over dim 1.
    When ``device`` is falsy, the device of the model's first parameter is used.
    An empty loader yields 0.0.
    """
    model.eval()
    target_device = device if device else next(model.parameters()).device

    n_right = 0
    n_seen = 0
    with torch.no_grad():
        for batch in dataloader:
            ids, mask, gold = (batch[key].to(target_device)
                               for key in ("input_ids", "attention_mask", "label"))
            logits, _pooled = model(ids, mask)
            n_right += (logits.argmax(dim=1) == gold).sum().item()
            n_seen += gold.size(0)

    if n_seen > 0:
        return n_right / n_seen
    return 0.0

In [None]:
def train_one_epoch(model, dataloader, optimizer, loss_fn):
    """Run one supervised pass over ``dataloader``, taking one optimizer step per batch.

    The model is put in train mode. Each batch's "input_ids", "attention_mask"
    and "label" are moved to DEVICE, and the logits returned by the model (first
    element of its output pair) are scored with ``loss_fn``. Progress is printed
    every 500 steps.
    """
    model.train()
    for step, batch in enumerate(dataloader):
        ids, mask, target = (batch[key].to(DEVICE)
                             for key in ("input_ids", "attention_mask", "label"))

        logits, _pooled = model(ids, mask)
        batch_loss = loss_fn(logits, target)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        if step % 500 == 0:
            print(f"  Step {step}/{len(dataloader)}...")

def train_base(base_jsonl_path, save_path="/content/drive/MyDrive/Deep Learning/HANet/model/hanet_base.pth",
               tokenizer=None):
    """Train a fresh HANetSimple on the base task and save its state_dict.

    The classifier head has ``len(label2id)`` outputs, i.e. one per event type
    listed in event_types.txt, including the types introduced by later tasks.

    Args:
        base_jsonl_path: JSONL file readable by EventDataset (the base-task split).
        save_path: destination of the saved weights; parent directories are created.
        tokenizer: tokenizer to use. When None, a "bert-base-uncased"
            BertTokenizer is loaded, matching the previous behaviour.

    Returns:
        The trained model, left in eval mode by the final ``evaluate`` call.
    """
    print("\n🚀 Training Base Task...")

    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    dataset = EventDataset(base_jsonl_path, label2id, tokenizer, max_len=MAX_LEN)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = HANetSimple(num_classes=len(label2id)).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        train_one_epoch(model, dataloader, optimizer, loss_fn)
        # Measured on the same data the model was trained on (there is no
        # held-out split here), so this reflects fit, not generalisation.
        train_acc = evaluate(model, dataloader)
        print(f"Epoch {epoch+1} - Train accuracy: {train_acc:.4f}")

    # Save model
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"✅ Model saved to {save_path}")

    return model


In [None]:
# Train the base model once; the resulting `model` is then fine-tuned by the incremental tasks.
base_task_path = "/content/drive/MyDrive/Deep Learning/HANet/data/base_task.jsonl"
model = train_base(base_task_path)


In [None]:
def _collect_augmented_features(model, loader, num_aug, sigma):
    """Encode each few-shot example once and return noisy copies of its pooled feature.

    Runs in eval mode under ``torch.no_grad`` (features are fixed targets for the
    classifier, so no graph is needed), then restores the model's previous mode.
    Returns a list of ``(feature, label)`` pairs with ``label`` an int class id.
    """
    was_training = model.training
    model.eval()
    items = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            _, pooled = model(input_ids, attention_mask)
            # assumes pooled is (batch, hidden); one row per example
            for row, label in zip(pooled, batch["label"].tolist()):
                for feature in augment_feature(row, num_aug=num_aug, sigma=sigma):
                    items.append((feature, label))
    model.train(was_training)
    return items


def _incremental_step(model, item, optimizer, loss_fn):
    """Run one optimisation step on a single training item and return its loss.

    ``item`` is ``("feature", tensor, label)`` for an augmented pooled feature,
    which goes through ``model.classifier`` only. It is ``("text", exemplar,
    label)`` for a replayed tokenised sentence (a dict with "input_ids" and
    "attention_mask"), which goes through the full model.
    """
    kind, payload, label = item
    # CrossEntropyLoss expects a (batch,) target for (batch, C) logits.
    target = torch.tensor([label], device=DEVICE)
    if kind == "feature":
        logits = model.classifier(payload.unsqueeze(0).to(DEVICE))
    else:
        input_ids = payload["input_ids"].unsqueeze(0).to(DEVICE)
        attention_mask = payload["attention_mask"].unsqueeze(0).to(DEVICE)
        logits, _ = model(input_ids, attention_mask)
    loss = loss_fn(logits, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train_incremental(model, tokenizer, label2id, task_id, memory, epochs=3, num_aug=5):
    """Fine-tune ``model`` on incremental task ``task_id`` with feature augmentation and replay.

    Training items are:
      * ``num_aug`` Gaussian-perturbed copies (std SIGMA) of the pooled feature
        of every few-shot example, which train the classifier head only;
      * every exemplar stored in ``memory`` (label -> list of dataset items),
        which trains the full model.
    Items are shuffled every epoch so classes are interleaved. Afterwards, the
    accuracy on the task's own few-shot examples is printed.

    NOTE(review): the real few-shot sentences only seed the augmented features;
    they are not themselves passed through BERT during training. Confirm this
    matches the intended HANet recipe.

    Returns:
        The same ``model`` object, updated in place.
    """
    import random

    print(f"\n🧩 Incremental Task {task_id}")

    path = f"/content/drive/MyDrive/Deep Learning/HANet/data/incremental_task_{task_id}.jsonl"
    few_dataset = EventDataset(path, label2id, tokenizer, max_len=MAX_LEN)
    few_loader = DataLoader(few_dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Prototypical augmentation of the new classes.
    train_items = [("feature", feature, label)
                   for feature, label in _collect_augmented_features(model, few_loader, num_aug, SIGMA)]
    # Replay exemplars; int() accepts both 0-d tensors and plain ints.
    train_items += [("text", ex, int(ex["label"])) for samples in memory.values() for ex in samples]

    # Explicitly enable train mode: the previous `evaluate` call left the model in eval mode.
    model.train()
    for _ in range(epochs):
        random.shuffle(train_items)
        for item in train_items:
            _incremental_step(model, item, optimizer, loss_fn)

    # Accuracy on this task's few-shot training examples (not held-out data, not old classes).
    acc = evaluate(model, few_loader)
    print(f"🎯 Task {task_id} Accuracy: {acc:.4f}")

    return model

In [None]:
# Replay memory: the first REPLAY_PER_CLASS exemplars (in file order) of every base-task class.
base_dataset = EventDataset("/content/drive/MyDrive/Deep Learning/HANet/data/base_task.jsonl", label2id, tokenizer, max_len=MAX_LEN)
memory = build_exemplar_memory(base_dataset, per_class=REPLAY_PER_CLASS)

In [None]:
NUM_INCREMENTAL_TASKS = 5  # must match the number of incremental_task_*.jsonl files generated
INCREMENTAL_TASK_TEMPLATE = "/content/drive/MyDrive/Deep Learning/HANet/data/incremental_task_{}.jsonl"

for task_id in range(1, NUM_INCREMENTAL_TASKS + 1):
    model = train_incremental(model, tokenizer, label2id, task_id, memory)

    # Add exemplars of the classes just learned to the replay memory, so later
    # tasks rehearse them too instead of only the base classes. Classes that
    # already have exemplars are left untouched.
    task_dataset = EventDataset(INCREMENTAL_TASK_TEMPLATE.format(task_id), label2id, tokenizer, max_len=MAX_LEN)
    for cls, exemplars in build_exemplar_memory(task_dataset, per_class=REPLAY_PER_CLASS).items():
        memory.setdefault(cls, exemplars)

print(f"\n✅ Full training (base + {NUM_INCREMENTAL_TASKS} incremental tasks) completed!")