In [None]:
import json

def load_and_inspect_data(file_path, n=5, max_field_length=50):
    with open(file_path, 'r') as f:
        data = json.load(f)
    print(f"Data type: {type(data)}")
    
    if isinstance(data, list):
        print(f"Number of records: {len(data)}")
        print("Sample records:")
        for i, record in enumerate(data[:n]):
            print(f"Record {i + 1}:")
            for key, value in record.items():
                if isinstance(value, str) and len(value) > max_field_length:
                    value = value[:max_field_length] + "..."
                print(f"  {key}: {value}")
            print()
    elif isinstance(data, dict):
        print(f"Keys: {list(data.keys())}")
        for key, entries in data.items():
            print(f"Sample from {key}:")
            if isinstance(entries, list):
                for i, record in enumerate(entries[:n]):
                    print(f"  Record {i + 1}:")
                    for k, v in record.items():
                        if isinstance(v, str) and len(v) > max_field_length:
                            v = v[:max_field_length] + "..."
                        print(f"    {k}: {v}")
                    print()
            else:
                print(f"  {entries}")
    else:
        print("Unknown data structure.")
    return data

data_file = 'all_data/www_2s_new.json'
dataset = load_and_inspect_data(data_file, n=2, max_field_length=30)


In [76]:
def preprocess_and_extract_train_data(data, max_field_length=50, n=2):

    train_data = []
    for i, record in enumerate(data[:n]):  
        print(f"Record {i + 1}:")
        if "train" in record:
            current_train_data = record["train"]
            train_data.extend(current_train_data)
            for j, train_item in enumerate(current_train_data[:n]):  
                print(f"  Train Item {j + 1}:")
                example_id = train_item.get("example_id", "N/A")
                print(f"    Example ID: {example_id}")
                
                stories = train_item.get("stories", [])
                for k, story in enumerate(stories[:n]):  
                    story_id = story.get("story_id", "N/A")
                    sentences = story.get("sentences", [])
                    truncated_sentences = [s[:max_field_length] + "..." if len(s) > max_field_length else s for s in sentences]
                    print(f"      Story {k + 1} (ID: {story_id}): Sentences: {truncated_sentences}")
        else:
            print("  No 'train' key found in record.")
    return train_data
train_data = preprocess_and_extract_train_data(dataset, max_field_length=30, n=2)
print(f"Extracted {len(train_data)} train items.")


Record 1:
  Train Item 1:
    Example ID: 0-C0
      Story 1 (ID: 0): Sentences: ['Tom bought a new dustbin for t...', 'Tom threw a broken plate in th...', 'Tom got some soup from the fri...', 'Tom put the soup in the microw...', 'Tom ate the cold soup.']
      Story 2 (ID: 0): Sentences: ['Tom bought a new dustbin for t...', 'Tom threw a broken plate in th...', 'Tom got some soup from the fri...', 'Tom put the soup in the microw...', 'Tom turned on the microwave.']
  Train Item 2:
    Example ID: 0-C1
      Story 1 (ID: 0): Sentences: ['Tom bought a new dustbin for t...', 'Tom threw a broken plate in th...', 'Tom unplugged the microwave.', 'Tom put the soup in the microw...', 'Tom turned on the microwave.']
      Story 2 (ID: 0): Sentences: ['Tom bought a new dustbin for t...', 'Tom threw a broken plate in th...', 'Tom got some soup from the fri...', 'Tom put the soup in the microw...', 'Tom turned on the microwave.']
Record 2:
  Train Item 1:
    Example ID: 0-O0
      Story 1 (ID: 0

In [98]:
def format_meta_tasks_dynamic(data, k_support=3, k_query=2):
    tasks = []

    for item in data:
        # 获取当前记录中的 stories
        stories = item.get("stories", [])
        total_stories = len(stories)

        if total_stories < 2:
            print(f"Skipping task due to insufficient stories: {total_stories} available.")
            continue

        dynamic_k_support = min(k_support, total_stories - 1)
        dynamic_k_query = min(k_query, total_stories - dynamic_k_support)

        support_set = stories[:dynamic_k_support]
        query_set = stories[dynamic_k_support:dynamic_k_support + dynamic_k_query]

        if not support_set or not query_set:
            print(f"Skipping task due to empty support or query set: {total_stories} available.")
            continue

        tasks.append({
            "support": [{"text": " ".join(story["sentences"]), "label": story["plausible"]} for story in support_set],
            "query": [{"text": " ".join(story["sentences"]), "label": story["plausible"]} for story in query_set]
        })

    print(f"Number of tasks: {len(tasks)}")
    return tasks
meta_tasks_dynamic = format_meta_tasks_dynamic(train_data)

if meta_tasks_dynamic:
    print(f"Number of tasks: {len(meta_tasks_dynamic)}")
    print(f"Sample task: {meta_tasks_dynamic[0]}")
else:
    print("No tasks were generated. Check the dataset or task formatting logic.")


Number of tasks: 3129
Number of tasks: 3129
Sample task: {'support': [{'text': 'Tom bought a new dustbin for the kitchen. Tom threw a broken plate in the dustbin. Tom got some soup from the fridge. Tom put the soup in the microwave. Tom ate the cold soup.', 'label': False}], 'query': [{'text': 'Tom bought a new dustbin for the kitchen. Tom threw a broken plate in the dustbin. Tom got some soup from the fridge. Tom put the soup in the microwave. Tom turned on the microwave.', 'label': True}]}


In [105]:
def tokenize_meta_tasks(tasks, tokenizer, max_length=128):
    tokenized_tasks = []

    for task in tasks:
        # 支持集
        support_inputs = tokenizer(
            [example["text"] for example in task["support"]],
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
        support_labels = torch.tensor([example["label"] for example in task["support"]], dtype=torch.long)

        # 查询集
        query_inputs = tokenizer(
            [example["text"] for example in task["query"]],
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
        query_labels = torch.tensor([example["label"] for example in task["query"]], dtype=torch.long)

        tokenized_tasks.append({
            "support": {"inputs": support_inputs, "labels": support_labels},
            "query": {"inputs": query_inputs, "labels": query_labels}
        })

    return tokenized_tasks

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenized_meta_tasks = tokenize_meta_tasks(meta_tasks, tokenizer)
print(f"Number of tokenized tasks: {len(tokenized_meta_tasks)}")
print(f"Sample tokenized task: {tokenized_meta_tasks[0] if tokenized_meta_tasks else 'No tokenized tasks'}")


Number of tokenized tasks: 3129
Sample tokenized task: {'support': {'inputs': {'input_ids': tensor([[  101,  3419,  4149,  1037,  2047,  6497,  8428,  2005,  1996,  3829,
          1012,  3419,  4711,  1037,  3714,  5127,  1999,  1996,  6497,  8428,
          1012,  3419,  2288,  2070, 11350,  2013,  1996, 16716,  1012,  3419,
          2404,  1996, 11350,  1999,  1996, 18302,  1012,  3419,  8823,  1996,
          3147, 11350,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,   

In [108]:
import torch
from transformers import BertForSequenceClassification, AdamW

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [116]:
def train_meta_learning(tasks, model, optimizer, num_epochs=3, print_interval=1):
    model.train()
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        total_loss = 0
        total_accuracy = 0
        task_count = 0

        for task in tasks:
            support_inputs = task["support"]["inputs"]
            support_labels = task["support"]["labels"].long().to(device)

            query_inputs = task["query"]["inputs"]
            query_labels = task["query"]["labels"].long().to(device)

            optimizer.zero_grad()
            support_outputs = model(
                input_ids=support_inputs["input_ids"].to(device),
                attention_mask=support_inputs["attention_mask"].to(device),
                token_type_ids=support_inputs["token_type_ids"].to(device),
            )
            logits = support_outputs.logits
            support_loss = loss_fn(logits, support_labels)
            support_loss.backward()
            optimizer.step()

            with torch.no_grad():
                query_outputs = model(
                    input_ids=query_inputs["input_ids"].to(device),
                    attention_mask=query_inputs["attention_mask"].to(device),
                    token_type_ids=query_inputs["token_type_ids"].to(device),
                )
                query_logits = query_outputs.logits
                query_preds = torch.argmax(query_logits, dim=1)
                query_accuracy = (query_preds == query_labels).float().mean().item()

            total_loss += support_loss.item()
            total_accuracy += query_accuracy
            task_count += 1

        avg_loss = total_loss / task_count
        avg_accuracy = total_accuracy / task_count

        if (epoch + 1) % print_interval == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Average Loss: {avg_loss:.4f}, Overall Accuracy: {avg_accuracy:.4f}")
train_meta_learning(tokenized_meta_tasks, model, optimizer, num_epochs=3)


Epoch 1/3, Average Loss: 0.2405, Overall Accuracy: 0.7939
Epoch 2/3, Average Loss: 0.2122, Overall Accuracy: 0.8114
Epoch 3/3, Average Loss: 0.1675, Overall Accuracy: 0.8073
