In [1]:
import sys
# Make the parent directory importable so the local `CAT` package can be imported.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import CAT
import json
import torch
import logging
import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [3]:
def setuplogger(level=logging.INFO):
    """Send root-logger records to stdout formatted as ``[LEVEL time] message``.

    Idempotent: calling it again (e.g. when the notebook cell is re-run) reuses
    the stdout handler installed by a previous call instead of stacking another
    one, which would otherwise print every log record once per call.

    Args:
        level: threshold applied to the root logger and to the stdout handler.
            Defaults to ``logging.INFO``, the previously hard-coded value.
    """
    handler_name = 'setuplogger.stdout'
    root = logging.getLogger()
    root.setLevel(level)
    # Reuse the handler from an earlier call rather than adding a duplicate.
    for existing in root.handlers:
        if existing.get_name() == handler_name:
            existing.setLevel(level)
            return
    handler = logging.StreamHandler(sys.stdout)
    handler.set_name(handler_name)
    handler.setLevel(level)
    handler.setFormatter(logging.Formatter("[%(levelname)s %(asctime)s] %(message)s"))
    root.addHandler(handler)

In [4]:
# Attach the stdout log handler defined above so logging.info(...) records are printed in the cell output.
setuplogger()

In [5]:
import random

# Seed every RNG that may be used during adaptive testing (Python's `random`,
# numpy's legacy global RNG and torch) so that runs are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7ff26265fa10>

In [6]:
# TensorBoard: one log directory per notebook execution.
# The timestamp avoids ':' (not allowed in Windows paths) and includes seconds so
# that two runs started within the same minute do not share a directory and mix curves.
log_dir = f"../logs/{datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}/"
print(log_dir)
writer = SummaryWriter(log_dir)

../logs/2021-03-09-14:51/


In [7]:
# Dataset to evaluate on: the sub-directory of ../data that is read below.
dataset = 'assistment'

# Model settings; passed verbatim as keyword arguments to the model constructor.
config = dict(
    learning_rate=0.0025,
    batch_size=2048,
    num_epochs=8,
    num_dim=1,          # used by IRT or MIRT
    device='cpu',
    prednet_len1=128,   # only used by NeuralCD
    prednet_len2=64,    # only used by NeuralCD
)

# Number of selection rounds per strategy (fixed test length).
test_length = 5

# Question-selection strategies to compare; each one gets its own run.
strategies = [
    CAT.strategy.RandomStrategy(),
    CAT.strategy.MFIStrategy(),
    CAT.strategy.KLIStrategy(),
]

# Pre-trained model checkpoint, reloaded before each strategy's run.
ckpt_path = '../ckpt/irt.pt'

In [8]:
# read datasets
test_triplets = pd.read_csv(f'../data/{dataset}/test_triples.csv', encoding='utf-8').to_records(index=False)
# `with` closes each JSON file deterministically (bare open() leaked the handles);
# an explicit encoding avoids the platform-dependent default, matching read_csv above.
with open(f'../data/{dataset}/concept_map.json', 'r', encoding='utf-8') as fp:
    # JSON object keys are always strings; convert them back to int keys.
    concept_map = {int(k): v for k, v in json.load(fp).items()}
with open(f'../data/{dataset}/metadata.json', 'r', encoding='utf-8') as fp:
    metadata = json.load(fp)

In [9]:
# Build the adaptive-testing dataset; the size arguments come straight from metadata.json.
test_data = CAT.dataset.AdapTestDataset(
    test_triplets,
    concept_map,
    *(metadata[key] for key in ('num_test_students', 'num_questions', 'num_concepts')),
)

In [10]:
import warnings
warnings.filterwarnings("ignore")

In [11]:
def _log_results(results, strategy_name, iteration):
    """Log every metric to stdout and to TensorBoard.

    Args:
        results: mapping of metric name -> value, as returned by ``model.evaluate``.
        strategy_name: curve label inside each metric's TensorBoard chart.
        iteration: x-axis step (selection round; 0 = before any selection).
    """
    for name, value in results.items():
        logging.info(f'{name}:{value}')
        # one chart per metric, one curve per strategy
        writer.add_scalars(name, {strategy_name: value}, iteration)


for strategy in strategies:
    # Fresh model loaded from the same checkpoint for every strategy.
    model = CAT.model.IRTModel(**config)
    model.init_model(test_data)
    model.adaptest_load(ckpt_path)
    # presumably clears the selections made during the previous strategy's run — verify in CAT.dataset
    test_data.reset()

    logging.info('-----------')
    logging.info(f'start adaptive testing with {strategy.name} strategy')

    # Iteration 0 is the baseline before any selection. It is now also written to
    # TensorBoard so that each curve includes its starting point.
    logging.info('Iteration 0')
    results = model.evaluate(test_data)
    _log_results(results, strategy.name, 0)

    for it in range(1, test_length + 1):
        logging.info(f'Iteration {it}')
        # select question(s): a mapping student -> question
        selected_questions = strategy.adaptest_select(model, test_data)
        for student, question in selected_questions.items():
            test_data.apply_selection(student, question)
        # update the model with the current test state, then evaluate it
        model.adaptest_update(test_data)
        results = model.evaluate(test_data)
        _log_results(results, strategy.name, it)

# SummaryWriter buffers events; flush so TensorBoard sees every point without
# waiting for the periodic flush or for the kernel to shut down.
writer.flush()

[INFO 2021-03-09 14:51:04,289] -----------
[INFO 2021-03-09 14:51:04,290] start adaptive testing with Random Select Strategy strategy
[INFO 2021-03-09 14:51:04,291] Iteration 0
[INFO 2021-03-09 14:51:04,308] auc:0.6484533447389293
[INFO 2021-03-09 14:51:04,309] cov:0.0
[INFO 2021-03-09 14:51:04,309] Iteration 1
[INFO 2021-03-09 14:51:04,344] auc:0.6489562662794149
[INFO 2021-03-09 14:51:04,347] cov:0.05801618621590955
[INFO 2021-03-09 14:51:04,349] Iteration 2
[INFO 2021-03-09 14:51:04,382] auc:0.6487346765890865
[INFO 2021-03-09 14:51:04,383] cov:0.11609196598657111
[INFO 2021-03-09 14:51:04,384] Iteration 3
[INFO 2021-03-09 14:51:04,413] auc:0.6500624347642152
[INFO 2021-03-09 14:51:04,413] cov:0.1612808712341023
[INFO 2021-03-09 14:51:04,414] Iteration 4
[INFO 2021-03-09 14:51:04,443] auc:0.6512111930010926
[INFO 2021-03-09 14:51:04,443] cov:0.20574638420300764
[INFO 2021-03-09 14:51:04,444] Iteration 5
[INFO 2021-03-09 14:51:04,473] auc:0.6514404203673256
[INFO 2021-03-09 14:51:04,