Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

タグ数を徐々に増やしながら学習するオプションの追加、persistent_workersに関する軽微なバグ修正 #322

Merged
merged 11 commits into from
Mar 26, 2023

Conversation

u-haru
Copy link
Contributor

@u-haru u-haru commented Mar 24, 2023

token_warmup

token_warmup_minを基準として、token_warmup_stepになるまで徐々にタグ数を増やしながら学習させる実装です。どれだけ効果があるかは未知数ですが、学習したい要素以外のtokenへの影響が多少軽減されるはずです。

torch.utils.data.DataLoaderの関係でエポック毎にしか変更されないため、少ないエポック数だと正常に動作しない可能性があります。

persistent_workersに関する軽微なバグ修正

persistent_workers使用時にcaption_dropout_every_n_epochsを使うと、エポック毎のワーカーのリセットがされない関係でset_current_epoch()が作用しなくなり、caption_dropout_every_n_epochsが使えなくなります(上記のtoken_warmupも同様)。

なので、競合時はargs.persistent_data_loader_workersをFalseに書き変えるようにしました。

思いつきの実装&修正なので、必要なければcloseしてもらって大丈夫です。

@kohya-ss
Copy link
Owner

PRありがとうございます。タグを徐々に増やすのは興味深いですね。ただ効果は限定的な気もしますので、コード追加とメンテナンスコストを考える必要がありそうです。

またpersistent_workers使用時の問題の件ですが、set_current_epoch()自体が呼ばれなくなる、ということでよろしいでしょうか。もしそうだとすると、そこでdatasetのshuffleをしているので別の問題がありそうです。

ただ手元で試したところ、persistent_data_loader_workersを指定しても、メソッドは呼ばれているようでした。何か呼ばれなくなる追加の条件等があるのでしょうか。お教えいただければ幸いです。

@u-haru
Copy link
Contributor Author

u-haru commented Mar 24, 2023

効果はメンテナンスコストに比べると確かに低そうです(実際効果あるのか分かってないので…)
ホントに思いつきだったのでその辺は深く考えてませんでした、すみません

persistent_workersについてですが、
1.タグ数変化の動作を確かめるためにprocess_caption()内でprintしながら確かめていたが、全く変化しなかった
2.set_current_step()等は正しく呼ばれていた
3.persistent_workersを無効化すると、epochの変わり目でタグ数が変化するようになった
といった経緯で気づきました。

恐らくですがtorch.utils.data.DataLoaderがデータセットをコピーしていて、train_dataset_groupDataLoaderから呼ばれるデータセットが別のものを参照しているのではないかと思います。
それで学習時に
・関数が呼ばれてるのはtrain_dataset_group
・実際にデータをロードしてるのはDataLoaderの内部のデータセット
みたいなことになっているのかなと(persistent_workers無効時はエポック毎に再コピーしてる?)。
実際に学習時のtrain_dataset_group.set_current_epoch()等をコメントアウトしてから学習直前にdel train_dataset_groupをしてみたのですが、問題なく動作しちゃってるので…
DataLoaderは動的なデータセットが想定されてないのかもしれません。

@kohya-ss
Copy link
Owner

効果はメンテナンスコストに比べると確かに低そうです(実際効果あるのか分かってないので…)
ホントに思いつきだったのでその辺は深く考えてませんでした、すみません

いえ、直感的には効果ありそうですし(学習率のwarm upと組み合わせても面白いかもしれません)、コストもそこまで大きくないのですが、一度オプションを追加してしまうと削除はなかなか難しいので……。比較結果などあると良いのですが……。

またpersistent_workersとDataLoaderについての詳細、ありがとうございます。お書きいただいた内容を元に改めてこちらでも確認してみましたが、たしかに実際のDataLoaderから呼ばれるdatasetのインスタンスでは、set_current_epochが呼ばれていないようでした。

恐らくですがtorch.utils.data.DataLoaderがデータセットをコピーしていて、train_dataset_groupDataLoaderから呼ばれるデータセットが別のものを参照しているのではないかと思います。

こちらのご推測が正しいようです。

ただWindowsではpersistent_data_loader_workersを指定しないとepochの切り替わりが極めて遅くなりますし、また単にデータを繰り返して1 epochを長くする方法では(このPRを含めて)epoch単位で何かする機能に影響が出ますので、何かうまい方法を考える必要がありそうです。
(なんらかの方法でepochの切り替わりをコピーされたdataset側に伝えるなど。)

@u-haru
Copy link
Contributor Author

u-haru commented Mar 25, 2023

multiprocessingのValueを使ってcollater_fn側から送信することで、一応stepとepochの両方をステップ毎に送ることが出来るようになりました。ただこの実装は結構無理やりな気もします。

(あとcommitミスしてログが汚くなってしまいました、すみません)

@kohya-ss
Copy link
Owner

ありがとうございます! collater_fnとValueを使うことでプロセス間で値の受け渡しができているようです。PythonにもPyTorchにもそこまで明るくないのでたいへん助かります。

また他のスクリプトへの対応もありがとうございます。タグ数の増加はむしろfine tuningの方が有効な気がしますね。

このままマージさせていただきます。

@kohya-ss kohya-ss merged commit 4f42f75 into kohya-ss:dev Mar 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants