In [1]:
import sys
sys.path.append('..')

In [3]:
import CAT
import json
import torch
import logging
import datetime
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
!pip install tensorboardX
from tensorboardX import SummaryWriter


Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl.metadata (5.8 kB)
Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
Installing collected packages: tensorboardX
Successfully installed tensorboardX-2.6.2.2


In [11]:
def setuplogger():
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter("[%(levelname)s %(asctime)s] %(message)s")
    handler.setFormatter(formatter)
    root.addHandler(handler)

In [12]:
setuplogger()

In [13]:
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x1ca6dc7fe70>

In [14]:
# tensorboard
log_dir = f"../logs/{datetime.datetime.now().strftime('%Y-%m-%d-%H:%M')}/"
log_dir = f"../logs/"
print(log_dir)
writer = SummaryWriter(log_dir)

../logs/


In [15]:
# choose dataset here
import CAT.strategy


dataset = 'assistment'
# modify config here
config = {
    'learning_rate': 0.0025,
    'batch_size': 2048,
    'num_epochs': 8,
    'num_dim': 1, # for IRT or MIRT
    'device': 'cpu',
    # for NeuralCD
    'prednet_len1': 128,
    'prednet_len2': 64,
    # for BOBCAT
    'policy':'notbobcat',
    'betas': (0.9, 0.999),
    'policy_path': 'policy.pt',
    # for NCAT
    'THRESHOLD' :300,
    'start':0,
    'end':3000
    
}
# fixed test length
test_length = 5
# choose strategies here
#strategies = [CAT.strategy.RandomStrategy(), CAT.strategy.MFIStrategy(), CAT.strategy.KLIStrategy()]
strategies = [CAT.strategy.NCATs()]
# modify checkpoint path here
ckpt_path = '../ckpt/irt.pt'
bobcat_policy_path =config['policy_path']

In [16]:
# read datasets
test_triplets = pd.read_csv(f'dataset/test_triples.csv', encoding='utf-8').to_records(index=False)
concept_map = json.load(open(f'dataset/concept_map.json', 'r'))
concept_map = {int(k):v for k,v in concept_map.items()}
metadata = json.load(open(f'dataset/metadata.json', 'r'))

In [17]:
test_data = CAT.dataset.AdapTestDataset(test_triplets, concept_map,
                                        metadata['num_test_students'], 
                                        metadata['num_questions'], 
                                        metadata['num_concepts'])

In [18]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
for strategy in strategies:
    avg =[]
    model = CAT.model.IRTModel(**config)
    #model = CAT.model.NCDModel(**config)
    model.init_model(test_data)
    model.adaptest_load(ckpt_path)
    test_data.reset()
    print(strategy.name)
    if strategy.name == 'NCAT':
        selected_questions = strategy.adaptest_select(test_data,concept_map,config,test_length)
        for it in range(test_length):
            for student, questions in selected_questions.items():
                test_data.apply_selection(student, questions[it])  
            model.adaptest_update(test_data)
            results = model.evaluate(test_data)
        # log results
            logging.info(f'Iteration {it}')
            for name, value in results.items():
                logging.info(f'{name}:{value}')
        continue
    if strategy.name == 'BOBCAT':
        real = {}
        real_data = test_data.data
        for sid in real_data:
            question_ids = list(real_data[sid].keys())
            real[sid]={}
            tmp={}
            for qid in question_ids:
                tmp[qid]=real_data[sid][qid]
            real[sid]=tmp
    logging.info('-----------')
    logging.info(f'start adaptive testing with {strategy.name} strategy')
    logging.info(f'Iteration 0')
    # evaluate models
    results = model.evaluate(test_data)
    for name, value in results.items():
        logging.info(f'{name}:{value}')
    S_sel ={}
    for sid in range(test_data.num_students):
        key = sid
        S_sel[key] = []
    selected_questions={}
    for it in range(1, test_length + 1):
        logging.info(f'Iteration {it}')
        # select question
        if strategy.name == 'BOBCAT':
            selected_questions = strategy.adaptest_select(model, test_data,S_sel)
            for sid in range(test_data.num_students):
                tmp = {}
                tmp[selected_questions[sid]] = real[sid][selected_questions[sid]]
                S_sel[sid].append(tmp)
        elif it == 1 and strategy.name == 'BECAT Strategy':
            for sid in range(test_data.num_students):
                untested_questions = np.array(list(test_data.untested[sid]))
                random_index = random.randint(0, len(untested_questions)-1)
                selected_questions[sid] = untested_questions[random_index]
                S_sel[sid].append(untested_questions[random_index])
        elif strategy.name == 'BECAT Strategy':    
            selected_questions = strategy.adaptest_select(model, test_data,S_sel)
            for sid in range(test_data.num_students):
                S_sel[sid].append(selected_questions[sid])
        else:
            selected_questions = strategy.adaptest_select(model, test_data)
        for student, question in selected_questions.items():
            test_data.apply_selection(student, question)       
        
        # update models
        model.adaptest_update(test_data)
        # evaluate models
        results = model.evaluate(test_data)
        # log results
        for name, value in results.items():
            logging.info(f'{name}:{value}')
            writer.add_scalars(name, {strategy.name: value}, it)

NCAT
1/910
2/910
3/910
4/910
5/910
6/910
7/910
8/910
9/910
10/910
11/910
12/910
13/910
14/910
15/910
16/910
17/910
18/910
19/910
20/910
21/910
22/910
23/910
24/910
25/910
26/910
27/910
28/910
29/910
30/910
31/910
32/910
33/910
34/910
35/910
36/910
37/910
38/910
39/910
40/910
41/910
42/910
43/910
44/910
45/910
46/910
47/910
48/910
49/910
50/910
51/910
52/910
53/910
54/910
55/910
56/910
57/910
58/910
59/910
60/910
61/910
62/910
63/910
64/910
65/910
66/910
67/910
68/910
69/910
70/910
71/910
72/910
73/910
74/910
75/910
76/910
77/910
78/910
79/910
80/910
81/910
82/910
83/910
84/910
85/910
86/910
87/910
88/910
89/910
90/910
91/910
92/910
93/910
94/910
95/910
96/910
97/910
98/910
99/910
100/910
101/910
102/910
103/910
104/910
105/910
106/910
107/910
108/910
109/910
110/910
111/910
112/910
113/910
114/910
115/910
116/910
117/910
118/910
119/910
120/910
121/910
122/910
123/910
124/910
125/910
126/910
127/910
128/910
129/910
130/910
131/910
132/910
133/910
134/910
135/910
136/910
137/910
138/910