## Train base

In [9]:
import json
from collections import defaultdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel


# ======= Config =======
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 64          # max_length handed to the tokenizer (pad / truncate)
BATCH_SIZE = 16       # mini-batch size for base-task training
LR = 2e-5             # AdamW learning rate
EPOCHS = 3            # number of passes over the training samples
REPLAY_PER_CLASS = 1  # presumably exemplars kept per class for replay -- TODO confirm usage
SIGMA = 0.1  # For prototypical augmentation

# ======= Tokenizer =======
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# ======= Label Setup =======
# One event type per line. The file is decoded explicitly as UTF-8 and blank
# lines are skipped so that an empty string never becomes a class of its own.
with open("/workspaces/HANet/data/event_type.txt", encoding="utf-8") as f:
    all_event_types = [line.strip() for line in f if line.strip()]  # all available event types
label2id = {etype: i for i, etype in enumerate(all_event_types)}  # event type -> integer id
id2label = {i: t for t, i in label2id.items()}  # reverse map: integer id -> event type


In [10]:
# ======= Dataset =======
class EventDataset(Dataset):
    """Sentence-level event-type classification samples read from a JSONL file.

    Every non-blank line of ``jsonl_file`` is parsed as one JSON object. For
    each entry of its optional ``"gold_evt_links"`` list whose ``"event_type"``
    is a key of ``label2id``, one ``(words, label)`` pair is stored, so a
    sentence with several valid events contributes several samples.

    Args:
        jsonl_file: Path of the JSONL file to read (decoded as UTF-8).
        label2id: Mapping from event-type string to integer class id.
        tokenizer: Callable invoked as ``tokenizer(sent, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')`` that
            returns a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, jsonl_file, label2id, tokenizer, max_len=64):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label2id = label2id
        # Explicit encoding keeps parsing independent of the platform locale.
        with open(jsonl_file, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank / trailing lines would make json.loads raise.
                    continue
                item = json.loads(line)
                words = item["words"]
                for evt in item.get("gold_evt_links", []):
                    label = evt["event_type"]
                    if label in label2id:  # keep only labels that are known
                        self.samples.append((words, label))

    def __len__(self):
        """Return the number of stored ``(words, label)`` samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return it as a dict of tensors.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` (the tokenizer
            outputs with their leading dimension squeezed away) and ``"label"``
            (a ``torch.long`` scalar tensor holding the class id).
        """
        words, label = self.samples[idx]
        sent = " ".join(words)
        inputs = self.tokenizer(sent, padding='max_length', truncation=True,
                                max_length=self.max_len, return_tensors='pt')
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": torch.tensor(self.label2id[label], dtype=torch.long)
        }


# ======= Model =======
class HANetSimple(nn.Module):
    """Pretrained BERT encoder followed by a single linear classifier.

    Args:
        num_classes: Number of output classes (width of the logit vector).
        model_name: Identifier or path handed to ``BertModel.from_pretrained``.
            Defaults to ``"bert-base-uncased"``.
    """

    def __init__(self, num_classes, model_name="bert-base-uncased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Size the head from the encoder's own config instead of a literal 768
        # so that checkpoints with a different hidden width also work.
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Encode the batch and classify its pooled representation.

        Returns:
            Tuple ``(logits, pooled)`` where ``pooled`` is the encoder's
            ``pooler_output`` and ``logits`` is the classifier applied to it.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoded.pooler_output
        return self.classifier(pooled), pooled

# ======= Memory Buffer for Replay =======
def build_exemplar_memory(dataset, per_class=1):
    """Keep the first ``per_class`` samples of every label for later replay.

    Args:
        dataset: Iterable of sample dicts. Each dict needs a ``"label"`` entry
            that is either a one-element tensor or a plain integer.
        per_class: Maximum number of exemplars stored per label.

    Returns:
        ``defaultdict(list)`` mapping each integer label seen in ``dataset`` to
        the list of sample dicts kept for it, in dataset order.
    """
    memory = defaultdict(list)
    for sample in dataset:  # not named `input`: that would shadow the builtin
        label = sample["label"]
        if isinstance(label, torch.Tensor):
            # Tensor labels are converted to a plain int to be usable as a dict key.
            label = label.item()
        bucket = memory[label]
        if len(bucket) < per_class:
            bucket.append(sample)
    return memory

# ======= Prototypical Augmentation =======
def augment_feature(rep, num_aug=5, sigma=0.1):
    """Return ``num_aug`` noisy copies of ``rep``.

    Each copy is ``rep`` plus standard-normal noise of the same shape scaled
    by ``sigma``; the copies are returned as a list of tensors.
    """
    augmented = []
    for _ in range(num_aug):
        noise = torch.randn_like(rep)
        augmented.append(rep + noise * sigma)
    return augmented

In [11]:
################## evaluate
import torch
import json
def evaluate(model, dataloader, device=None):
    """Compute classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there). When ``device`` is
    not given, the device of the model's first parameter is used. Each batch
    must provide ``"input_ids"``, ``"attention_mask"`` and ``"label"``; the
    model is expected to return ``(logits, <anything>)``.

    Returns:
        Fraction of correctly predicted labels, or ``0.0`` when no sample
        was seen.
    """
    model.eval()
    if not device:
        device = next(model.parameters()).device

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in dataloader:
            ids, mask, gold = (
                batch[key].to(device)
                for key in ("input_ids", "attention_mask", "label")
            )
            logits, _ = model(ids, mask)
            predicted = logits.argmax(dim=1)
            n_correct += predicted.eq(gold).sum().item()
            n_seen += gold.size(0)

    if n_seen == 0:
        return 0.0
    return n_correct / n_seen

In [12]:
def train_base(base_jsonl_path):
    """Train a fresh ``HANetSimple`` on the base task and return it.

    Builds an ``EventDataset`` from ``base_jsonl_path``, trains for ``EPOCHS``
    epochs with AdamW and cross-entropy loss, and after every epoch prints the
    accuracy measured by ``evaluate`` on the same (training) dataloader.
    """
    print("\n🚀 Training Base Task...")

    train_set = EventDataset(base_jsonl_path, label2id, tokenizer, max_len=MAX_LEN)
    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    model = HANetSimple(num_classes=len(label2id)).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch_no in range(1, EPOCHS + 1):
        model.train()
        for batch in train_loader:
            ids, mask, targets = (
                batch[key].to(DEVICE)
                for key in ("input_ids", "attention_mask", "label")
            )

            # Clear stale gradients, then forward / backward / update.
            optimizer.zero_grad()
            logits, _ = model(ids, mask)
            criterion(logits, targets).backward()
            optimizer.step()

        # Accuracy here is measured on the training data itself.
        acc = evaluate(model, train_loader)
        print(f"Epoch {epoch_no} - Accuracy: {acc:.4f}")

    return model

In [None]:
def train_incremental(model, tokenizer, label2id, task_id, memory, data_path=None):
    """Fine-tune ``model`` on one few-shot incremental task with replay.

    Steps:
      1. Encode every few-shot sample once (eval mode, no autograd) and create
         5 noisy copies of its pooled representation via ``augment_feature``.
      2. Append the replay exemplars from ``memory`` to the training list.
      3. Train for ``EPOCHS`` epochs, one sample per optimizer step. Augmented
         features go through ``model.classifier`` only; replay exemplars go
         through the full model.
      4. Print the accuracy that ``evaluate`` reports on the few-shot loader.

    Args:
        model: Network returning ``(logits, pooled)`` and exposing ``classifier``.
        tokenizer: Tokenizer handed to ``EventDataset``.
        label2id: Mapping from event-type string to class id.
        task_id: Task number; used in the log lines and the default data path.
        memory: Mapping label -> list of sample dicts, each with ``"label"``,
            ``"input_ids"`` and ``"attention_mask"``.
        data_path: Optional JSONL path of the task data. When ``None`` the
            original hard-coded location for ``task_id`` is used.

    Returns:
        The same ``model`` object after training.
    """
    print(f"\n🧩 Incremental Task {task_id}")

    if data_path is None:
        data_path = f"/workspaces/HANet/datasets/hanet_minimal/incremental_task_{task_id}.jsonl"
    few_dataset = EventDataset(data_path, label2id, tokenizer, max_len=MAX_LEN)
    few_loader = DataLoader(few_dataset, batch_size=1, shuffle=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_batch = []

    # === Prototypical Augmentation ===
    # Pure feature extraction: eval mode keeps dropout off so the
    # representations are deterministic, and no_grad avoids building an
    # autograd graph that would be thrown away immediately.
    model.eval()
    with torch.no_grad():
        for few in few_loader:
            input_ids = few["input_ids"].to(DEVICE)
            attention_mask = few["attention_mask"].to(DEVICE)
            label = few["label"].item()  # batch_size=1 -> single label

            _, rep = model(input_ids, attention_mask)  # rep: [1, 768]
            for r in augment_feature(rep.squeeze(0), num_aug=5, sigma=SIGMA):
                train_batch.append({
                    "label": label,
                    "aug_feature": r.detach()
                })

    # === Add replay exemplars ===
    for samples in memory.values():
        train_batch.extend(samples)

    # === Train ===
    for epoch in range(EPOCHS):
        model.train()
        for b in train_batch:
            label = b["label"]
            if isinstance(label, torch.Tensor):
                label = label.item()
            label = torch.tensor(label, dtype=torch.long).unsqueeze(0).to(DEVICE)

            if "aug_feature" in b:
                # CASE: augmented feature vector, bypass BERT
                feature = b["aug_feature"].unsqueeze(0).float().to(DEVICE)
                logits = model.classifier(feature)
            else:
                # CASE: regular input replay through the full model
                input_ids = b["input_ids"].unsqueeze(0).to(DEVICE)
                attention_mask = b["attention_mask"].unsqueeze(0).to(DEVICE)
                logits, _ = model(input_ids, attention_mask)

            loss = loss_fn(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # === Evaluate ===
    acc = evaluate(model, few_loader)
    print(f"🎯 Task {task_id} Accuracy: {acc:.4f}")

    return model

: 

In [None]:
# Train the base task
base_task_path = "/workspaces/HANet/data/base_task_random_1000.jsonl"
model = train_base(base_task_path)

# Build the replay memory buffer from the same base-task file
base_dataset = EventDataset(base_task_path, label2id, tokenizer, max_len=MAX_LEN)
memory = build_exemplar_memory(base_dataset, per_class=1)

# Train the incremental task with augmentation + memory replay
model = train_incremental(model, tokenizer, label2id, task_id=1, memory=memory)



🚀 Training Base Task...


In [None]:
# Count the records (one per line) in the base-task JSONL file.
with open('/workspaces/HANet/data/base_task_random_1000.jsonl', mode='r', encoding='utf-8') as f:
    count = len(f.readlines())
print(f"Tổng số bản ghi: {count}")