In [1]:
import torch
import torch.nn as nn

import pandas as pd
import numpy as np
import random

from model_new import BertForTokenClassification
import utils.NERutils as nu
import utils.query_funcs as q

from transformers import AutoConfig, AutoTokenizer

from torch.utils.data import DataLoader, SubsetRandomSampler


### Link for inspiration

https://www.scaleway.com/en/blog/active-learning-pytorch/

In [2]:
# Define tokenizer
# `bert_model_name` is the checkpoint identifier handed to
# AutoTokenizer.from_pretrained; the tokenizer is later passed to the
# NERdataset constructors.
bert_model_name = "bert-base-multilingual-cased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)

#### Load datasets

In [3]:
# Parquet files holding the train / dev / test splits, given as paths
# relative to the working directory.
train_path = "data/train.parquet"
dev_path = "data/dev.parquet"
test_path = "data/test.parquet"

In [4]:
# Value forwarded as the `filter=` argument when the datasets are built.
# NOTE(review): this assignment shadows Python's built-in `filter` for the
# rest of the notebook; renaming it needs a matching change wherever the
# datasets are constructed.
filter = 'Legal'

In [5]:
# Build the train and test datasets from their parquet files, sharing one
# tokenizer and the same `filter` value.
# NOTE(review): presumably `filter` restricts the data to one subset
# (e.g. a domain) - confirm against utils/NERutils.
train_dataset = nu.NERdataset(dataset_path=train_path, tokenizer=bert_tokenizer, filter=filter)
# The dev split is currently not loaded.
#dev_dataset = nu.NERdataset(dataset_path=dev_path, tokenizer=bert_tokenizer)
test_dataset = nu.NERdataset(dataset_path=test_path, tokenizer=bert_tokenizer, filter=filter)

['O', 'O', 'DATE', 'DATE', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'EVENT', 'EVENT', 'EVENT', 'EVENT', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O

In [1]:
# Loader over the test set in fixed order (shuffle=False), used when the
# model is evaluated inside the active-learning loop.
# Needs the import cell to have run in the current kernel, since
# DataLoader comes from torch.utils.data.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

NameError: name 'DataLoader' is not defined

#### Get pretrained model

In [9]:
# Pick the compute device in order of preference: Apple MPS, then CUDA,
# then CPU. The conditional expression short-circuits, so the CUDA check
# only runs when MPS is unavailable.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

In [10]:
# Config
# The checkpoint name is set again here so this cell does not depend on
# the tokenizer cell. The label count and the id<->label mappings are
# taken from the training dataset.
bert_model_name = "bert-base-multilingual-cased"
bert_config = AutoConfig.from_pretrained(
    bert_model_name, 
    num_labels=len(train_dataset.tags), 
    id2label=train_dataset.index2tag, 
    label2id=train_dataset.tag2index
)

# Instantiate the project's token-classification model from the pretrained
# checkpoint and move it to the selected device.
model = BertForTokenClassification.from_pretrained(bert_model_name, config=bert_config, tags=train_dataset.tags).to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Load model
model.load_state_dict(torch.load("test_model", map_location=device))

<All keys matched successfully>

### Active learning

In [12]:
def random_query(data_loader, query_size=10):
    """Return the first ``query_size`` sample indices yielded by ``data_loader``.

    Each batch must unpack into exactly three items; the third one must
    offer ``.tolist()`` and holds the indices of the samples in that batch.
    Batches are consumed only until enough indices have been gathered, so
    the randomness of the selection comes entirely from the loader's order.

    Args:
        data_loader: Iterable of 3-item batches as described above.
        query_size: Maximum number of indices to return.

    Returns:
        A list with at most ``query_size`` indices, in the order they were
        produced by the loader.
    """
    picked = []

    for _first, _second, positions in data_loader:
        picked += positions.tolist()

        # Keep pulling batches while we are still short of the request.
        if len(picked) < query_size:
            continue
        break

    return picked[:query_size]

In [13]:
def query_the_oracle(model, device, dataset, query_size=10, query_strategy='random', 
                     interactive=True, pool_size=0, batch_size=16, num_workers=0):
    """Select unlabeled samples from ``dataset`` and mark them as labeled.

    Entries of ``dataset.unlabeled_mask`` that are non-zero are treated as
    unlabeled. A DataLoader over those samples (or over a random pool of
    them) is handed to the query strategy, and the mask entry of every
    index the strategy returns is set to 0.

    Args:
        model: Not used by the 'random' strategy; kept for model-based
            strategies such as the planned margin query.
        device: Not used by the 'random' strategy.
        dataset: Object with a writable ``unlabeled_mask`` that can be
            wrapped in a ``DataLoader``.
        query_size: Number of sample indices requested from the strategy.
        query_strategy: ``'margin'`` is not implemented yet (a message is
            printed and the mask is left untouched); any other value uses
            ``q.random_query``.
        interactive: Currently unused.
        pool_size: When > 0, the candidates are a random pool of at most
            this many unlabeled samples; 0 uses every unlabeled sample.
        batch_size: Batch size of the candidate loader.
        num_workers: Worker count of the candidate loader.

    Returns:
        None. ``dataset.unlabeled_mask`` is updated in place.
    """
    unlabeled_idx = np.nonzero(dataset.unlabeled_mask)[0]

    # Pool-based sampling
    if pool_size > 0:
        # Positions start at 0 so the first unlabeled sample is eligible too.
        # The pool is clamped because random.sample raises ValueError when
        # asked for more items than the population holds.
        pool_positions = random.sample(range(len(unlabeled_idx)),
                                       min(pool_size, len(unlabeled_idx)))
        candidate_idx = unlabeled_idx[pool_positions]
    else:
        candidate_idx = unlabeled_idx

    pool_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                             sampler=SubsetRandomSampler(candidate_idx))

    # Strategies
    if query_strategy == 'margin':
        # TODO: sample_idx = margin_query(model, device, pool_loader, query_size)
        print("Method not implemented yet")
        return

    sample_idx = q.random_query(pool_loader, query_size)

    # Move the selected observations to the pool of labeled samples
    for sample in sample_idx:
        dataset.unlabeled_mask[sample] = 0

In [14]:
# Active-learning schedule
num_queries = 10           # number of query / train rounds
batch_size = 16
query_size = 5             # samples moved to the labeled set per round
query_strategy = 'random'
pool_size = 10             # random candidate pool per round (0 = all unlabeled samples)

# Optimisation settings
num_epochs = 1
learning_rate = 1e-05
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

for query in range(num_queries):

    # Optional arguments go by keyword: positionally, the sixth parameter of
    # query_the_oracle is `interactive`, so `pool_size` would never be set.
    query_the_oracle(model, device, train_dataset, query_size=query_size,
                     query_strategy=query_strategy, pool_size=pool_size,
                     batch_size=batch_size)

    # Samples whose mask entry is 0 have been labeled so far.
    labeled_idx = np.where(train_dataset.unlabeled_mask == 0)[0]
    labeled_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0, sampler=SubsetRandomSampler(labeled_idx))

    # train model
    # NOTE(review): `current_test_acc` is never refreshed from the evaluation
    # below, so this loop always runs exactly one fit/test pass per query.
    # Wire it to the accuracy produced by model.test() to get the intended
    # "train while the test accuracy improves" behaviour.
    previous_test_acc = 0
    current_test_acc = 1
    while current_test_acc > previous_test_acc:
        previous_test_acc = current_test_acc
        model.fit(num_epochs, labeled_loader, device, optimizer)
        model.test(device, test_loader)

        train_loss = model.train_loss[-1]
        # NOTE(review): despite its name this is read from model.train_acc.
        val_acc = model.train_acc[-1]

    # test model

KeyError: 10313

In [15]:
# Show the indices already marked as labeled (mask entry equal to 0).
np.where(train_dataset.unlabeled_mask == 0)[0]

array([], dtype=int64)

In [16]:
# Query settings, repeated here so the debugging cells that follow can be
# run without executing the full active-learning loop.
num_queries = 10
batch_size = 16
query_size = 5
query_strategy='random'
pool_size=10

In [17]:
# Optional arguments go by keyword: positionally, the sixth parameter of
# query_the_oracle is `interactive`, so `pool_size` would be left at 0.
query_the_oracle(model, device, train_dataset, query_size=query_size,
                 query_strategy=query_strategy, pool_size=pool_size)

KeyError: 8813

In [None]:
def query_the_oracle(model, device, dataset, query_size=10, query_strategy='random', 
                     interactive=True, pool_size=0, batch_size=16, num_workers=0):
    """Select unlabeled samples from ``dataset`` and mark them as labeled.

    Entries of ``dataset.unlabeled_mask`` that are non-zero are treated as
    unlabeled. A DataLoader over those samples (or over a random pool of
    them) is handed to the query strategy, and the mask entry of every
    index the strategy returns is set to 0.

    Args:
        model: Not used by the 'random' strategy; kept for model-based
            strategies such as the planned margin query.
        device: Not used by the 'random' strategy.
        dataset: Object with a writable ``unlabeled_mask`` that can be
            wrapped in a ``DataLoader``.
        query_size: Number of sample indices requested from the strategy.
        query_strategy: ``'margin'`` is not implemented yet (a message is
            printed and the mask is left untouched); any other value uses
            ``q.random_query``.
        interactive: Currently unused.
        pool_size: When > 0, the candidates are a random pool of at most
            this many unlabeled samples; 0 uses every unlabeled sample.
        batch_size: Batch size of the candidate loader.
        num_workers: Worker count of the candidate loader.

    Returns:
        None. ``dataset.unlabeled_mask`` is updated in place.
    """
    unlabeled_idx = np.nonzero(dataset.unlabeled_mask)[0]

    if pool_size > 0:
        # Positions start at 0 so the first unlabeled sample is eligible too.
        # The pool is clamped because random.sample raises ValueError when
        # asked for more items than the population holds.
        pool_positions = random.sample(range(len(unlabeled_idx)),
                                       min(pool_size, len(unlabeled_idx)))
        candidate_idx = unlabeled_idx[pool_positions]
    else:
        candidate_idx = unlabeled_idx

    pool_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                             sampler=SubsetRandomSampler(candidate_idx))

    if query_strategy == 'margin':
        # TODO: sample_idx = margin_query(model, device, pool_loader, query_size)
        print("Method not implemented yet")
        return

    sample_idx = q.random_query(pool_loader, query_size)

    # Move the selected observations to the pool of labeled samples
    for sample in sample_idx:
        dataset.unlabeled_mask[sample] = 0

In [23]:
# Indices of the samples whose mask entry is non-zero, i.e. still unlabeled.
unlabeled_idx = np.nonzero(train_dataset.unlabeled_mask)[0]
unlabeled_idx

array([    0,     1,     2, ..., 11759, 11760, 11761], dtype=int64)

In [25]:
# Draw `pool_size` distinct positions into `unlabeled_idx`. The range starts
# at 0 so the first unlabeled sample can be drawn as well.
pool_idx = random.sample(range(len(unlabeled_idx)), pool_size)
pool_idx

[7178, 168, 9646, 5041, 3309, 1823, 6997, 7592, 5852, 3306]

In [27]:
# Loader restricted to the drawn pool: the sampler visits only the dataset
# indices selected by `pool_idx`, in random order.
pool_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=0,
                                 sampler=SubsetRandomSampler(unlabeled_idx[pool_idx]))

In [28]:
# Display the loader object; this does not fetch any batch.
pool_loader

<torch.utils.data.dataloader.DataLoader at 0x2baf35c8710>

In [42]:
# Iterate the dataset directly (not through a DataLoader) and print the
# number of entries in each item.
# NOTE(review): the recorded run printed one line per item and then ended
# with a KeyError raised while indexing the dataset, instead of stopping
# cleanly. Possibly a label-based lookup on a filtered frame - check
# __getitem__ / __len__ in utils/NERutils.
for batch in test_dataset:
    print(len(batch))

5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5
5


KeyError: 239

In [29]:
# Run the project's random query strategy directly on the pool loader to
# reproduce the failure outside of query_the_oracle.
q.random_query(pool_loader, query_size)

KeyError: 7592

In [36]:
# Debugging aid: fetch a single batch from the pool loader and stop.
# The index-collection logic is commented out, so `indexes` stays empty.
indexes = []
for batch in pool_loader:
    #print(batch)
    #print(batch["index"].tolist())
    #indexes.extend(batch["index"].tolist())

    #if len(indexes) >= query_size:
    #
    break

KeyError: 7178

In [33]:
# Display the loader object again; this does not fetch any batch.
pool_loader

<torch.utils.data.dataloader.DataLoader at 0x2baf35c8710>