How do I run student predictions? #7

Closed
smr97 opened this issue Jan 1, 2020 · 2 comments

smr97 commented Jan 1, 2020

Hey, I am trying to reproduce your results and am interested in training several students with different numbers of hidden layers. I want to submit the student predictions to the GLUE website. I have been able to train student models with the PKD-skip procedure.

My question is, how do I make predictions from the student model? I guess I should change the run_glue_benchmark somehow. Any help in this regard will be appreciated.

intersun (Owner) commented Jan 2, 2020

The script is designed around my folder structure. The core code is shown below (starting from line 111).

Initialize your model with your trained model files:

# Paths to your trained encoder and classifier files
encoder_file, cls_file = encoder_file[0], cls_file[0]
# Build the encoder and classifier for this task
encoder_bert, classifier = init_model(task, output_all_layers, n_layer, config)
# Load the trained weights into both modules
encoder_bert = load_model(encoder_bert, encoder_file, args, 'exact', verbose=True)
classifier = load_model(classifier, cls_file, args, 'exact', verbose=True)

Then load the data and make predictions (starting from line 119):

# Build the dev-set dataloader
dev_examples, dev_dataloader, dev_label_ids = get_task_dataloader(task.lower(), 'dev', tokenizer, args, SequentialSampler, args.eval_batch_size)
# Run the model; the result holds the logits, accuracy, and loss
dev_res = eval_model_dataloader(encoder_bert, classifier, dev_dataloader, args.device, detailed=True, verbose=False)
# Predicted class = index of the largest logit
dev_pred_label = dev_res['pred_logit'].argmax(1)
logger.info('for dev, acc = {}, loss = {}'.format(dev_res['acc'], dev_res['loss']))
# Sanity check: recompute accuracy from the predicted labels
logger.info('debug dev acc = {}'.format((dev_label_ids.numpy() == dev_pred_label).mean()))
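For a GLUE submission, the predicted labels then need to be written to a file. The following is only a rough sketch and is not part of the repository: `argmax` and `write_glue_predictions` are hypothetical helpers, the `index`/`prediction` header is the tab-separated layout commonly used for GLUE uploads as far as I know, and the order of `label_names` is an assumption that must match how the task's labels were indexed during training. It uses plain Python lists, so logits from the snippet above would first be converted with something like `.tolist()`.

```python
import csv
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Return the index of the largest score in one row of logits."""
    return max(range(len(row)), key=row.__getitem__)


def write_glue_predictions(
    logits: Sequence[Sequence[float]],
    label_names: List[str],
    out_path: str,
) -> None:
    """Write one predicted label per example as a tab-separated file.

    logits: one row of class scores per example.
    label_names: class names ordered by class index (assumed mapping;
        check it against the label order used during training).
    """
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        # Assumed header layout for a GLUE submission file
        writer.writerow(["index", "prediction"])
        for i, row in enumerate(logits):
            writer.writerow([i, label_names[argmax(row)]])
```

A call might look like `write_glue_predictions(logits, ["entailment", "not_entailment"], "RTE.tsv")`, where `logits` comes from running the same evaluation on the test split instead of `'dev'`, assuming the dataloader supports that split. Regression tasks such as STS-B would write the raw score rather than a label name.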

Please let me know if it works.

Thanks for your interest.

smr97 (Author) commented Feb 15, 2020

Thank you, it works!

smr97 closed this as completed Feb 15, 2020