Code written for the Kaggle challenge: https://www.kaggle.com/competitions/retweet-prediction-challenge-2022/overview
Part of the course INF554: Introduction to Machine Learning at École Polytechnique (France).
Three models are available: No Text MLP ('mlp'), W2V MLP ('w2v', not mentioned in the report since it gives worse results than the simple No Text MLP) and CafayNet ('conv').
All hyperparameters are available in the config.yaml file and can be changed there. The current values are the ones that gave the best leaderboard result for CafayNet, as explained in the report.
To create the environment, run: conda env create -f environment.yaml
Set the dataset location with: conda env config vars set DATASET_PATH=<path_to_csv_files>
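As an illustration of how a script could pick up the DATASET_PATH variable set above, here is a minimal sketch. The helper name find_dataset_files and the assumption that the folder simply contains *.csv files are hypothetical and not taken from the repository code.

```python
import os
from pathlib import Path
from typing import List


def find_dataset_files(env_var: str = "DATASET_PATH") -> List[Path]:
    """Return the CSV files found in the folder named by the environment variable.

    Hypothetical helper: the real scripts may read the variable differently.
    """
    root = os.environ.get(env_var)
    if root is None:
        raise RuntimeError(
            f"{env_var} is not set; run "
            f"'conda env config vars set {env_var}=<path_to_csv_files>' "
            "and re-activate the environment."
        )
    folder = Path(root)
    if not folder.is_dir():
        raise FileNotFoundError(f"{env_var} points to a missing folder: {folder}")
    # Sort for a deterministic order across platforms.
    return sorted(folder.glob("*.csv"))
```

Note that variables set with conda env config vars set only become visible after the environment is re-activated.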
To track training, run: tensorboard --logdir lightning_logs --bind_all
- Run python generate_w2v.py to generate word2vec embeddings.
- Run python run_training.py -v <version_name> -m <model_name> (model_name in ['mlp', 'w2v', 'conv']) to train a model.
- If you wish to run a new training with pre-loaded weights, add the option -w <path_to_ckpt>.
- To generate a submission with a trained model, run python run_prediction.py -m <model_name> -w <path_to_ckpt_folder>.
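The command-line options listed above can be pictured with the following argparse sketch. It is not the actual content of run_training.py: the short flags -v, -m and -w and the three model names come from this README, while the long option names, help strings and the function name build_training_parser are assumptions made for illustration.

```python
import argparse

# Model names accepted by -m, as listed in this README.
MODEL_NAMES = ("mlp", "w2v", "conv")


def build_training_parser() -> argparse.ArgumentParser:
    """Build a parser mirroring the documented training options (hypothetical)."""
    parser = argparse.ArgumentParser(description="Train a retweet prediction model.")
    parser.add_argument("-v", "--version", required=True,
                        help="name under which this training run is logged")
    parser.add_argument("-m", "--model", required=True, choices=MODEL_NAMES,
                        help="model to train")
    parser.add_argument("-w", "--weights", default=None,
                        help="optional path to a checkpoint with pre-loaded weights")
    return parser


# Parse an explicit argument list instead of sys.argv to keep the sketch self-contained.
args = build_training_parser().parse_args(["-v", "baseline", "-m", "conv"])
print(args.version, args.model, args.weights)
```

Restricting -m with choices makes argparse reject an unknown model name before any training starts.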