Telegram @zzZaur
Бот получает изображение кошки и выдает изображение, стилизованное под тигра.
CycleGAN-архитектура модели взята из статьи:
-
Мы работаем с изображениями 256х256.
-
Генераторы содержат 3 сверточных слоя и 9 ResNet блоков.
-
Финальный сверточный слой дискриминаторов на выходе имеет не число, а тензор 70х70 (т.е. содержит оценку для разных перекрывающихся частей изображения).
Обучение проводилось на kaggle-датасетах для кошек и тигров.
Из датасетов исключены слишком маленькие изображения (меньше 10 Кбайт). Также отфильтрованы некачественные фото тигров (оригинальный датасет собирался из кадров разных видео).
-
Из датасетов берутся два изображения (один кот и один тигр).
-
Генерируется фейковый тигр (кот после генератора cat2tiger) и фейковый кот (тигр после генератора tiger2cat).
-
Дискриминатор котов обучается на правильно размеченной паре (кот, фейковый кот), дискриминатор тигров обучается на паре (тигр, фейковый тигр). Функция потерь MSE.
-
Фейковый тигр пропускается через tiger2cat и получается кот (будем называть его цикличным). Так же получаем цикличного тигра.
-
Далее обучаем генераторы. Функция потерь состоит из двух частей: первая - MSE от выхода дискриминатора на фейковой картинке до противоположного класса (т.е. пытаемся обмануть дискриминатор) и L1 между оригинальным и цикличным фото.
Чтобы дать фору генератору, для обучения дискриминатора изображения поступают через буфер.
Обучение проводилось в Google Colab (notebook).
Бот написан с помощью модуля aiogram.
Необходимые модули можно подгрузить с помощью команды:
pip install -r requirements.txt
Перед запуском нужно положить токен бота в переменную среды TG_BOT_TOKEN:
export TG_BOT_TOKEN="XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
python app.py
Параметры обученного генератора должны лежать в той же директории, что и бот, в файле model.pt (в colab они сохраняются в файл gen_TT_NNN.pt, где NNN - номер эпохи, т.е. для загрузки в бота его нужно переименовать). Файл хранит словарь, в котором параметры лежат по ключу "state_dict".