<a href="https://colab.research.google.com/github/w-okada/beatrice-trainer-colab/blob/master/BeatriceV2_Trainer_Notebook_rev_02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title 1. Install Modules
# Install the helper package that provides run_and_log; all pip output is discarded.
!pip install easy-colab-ui-extension > /dev/null 2>&1

from easy_colab_ui_extension.run_and_log import run_and_log
# Install the UI package and pyworld (pinned to 0.3.4).
# NOTE(review): both run_and_log calls pass the same log_file; whether the second
# call appends to or overwrites pip_log.txt is not visible here - confirm.
run_and_log("pip install colab-easy-ui2 pyworld==0.3.4", log_file="pip_log.txt", tail_lines=3, expected_lines=500)
# Shallow-clone (depth 1) the trainer repository into the current directory.
# Later cells use /content/beatrice-trainer, so this presumably runs with cwd=/content.
run_and_log("git -c core.progress=plain clone --depth 1 --progress https://huggingface.co/fierce-cats/beatrice-trainer", log_file="pip_log.txt", tail_lines=3, expected_lines=500)

import os
# Directory layout used by every later cell: the trainer checkout plus its
# "data" (training input) and "output" (training results) sub-directories.
working_dir = os.path.join("/", "content", "beatrice-trainer")
data_dir, output_dir = (
    os.path.join(working_dir, leaf) for leaf in ("data", "output")
)
# exist_ok keeps re-running this cell harmless.
for required_dir in (data_dir, output_dir):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
# @title 2. Setting
import torch
from colab_easy_ui.beatrice_trainer_main import wrapped_run_server
from IPython.display import HTML
from IPython.display import display

import portpicker
import time

def get_device():
  """Return the name of CUDA device 0 if CUDA is available, otherwise "cpu"."""
  if not torch.cuda.is_available():
    return "cpu"
  return torch.cuda.get_device_name(0)
# Device label (GPU name or "cpu") that is passed to the setup page builder below.
device=get_device()

# Let portpicker choose a currently unused local port for the settings server.
port = portpicker.pick_unused_port()

# Start the colab-easy-ui server on that port.
# NOTE(review): extract_to / extract_to2 presumably name the directories into which
# uploaded archives are unpacked (data_dir and output_dir) - confirm against colab_easy_ui.
wrapped_run_server(port=port, colab=True, extract_to=data_dir, extract_to2=output_dir)
# Wait two seconds before rendering the page - presumably to give the server time to start.
time.sleep(2)
from colab_easy_ui.beatrice_trainer_main import get_html_setup
# Build the setup page HTML for this port/device and render it in the cell output.
html = get_html_setup(port, device)
display(HTML(html))

In [None]:
# @title 3. tensorboard
from google.colab.output import eval_js
import os
# Fixed local port on which TensorBoard is started.
PORT=18002
# Alternative launch through notebook magics, currently disabled:
# %load_ext tensorboard
# %tensorboard --logdir=output --port {PORT}
# Launch TensorBoard as a background shell job (trailing "&") reading from output_dir.
os.system(f'tensorboard --logdir={output_dir} --host 0.0.0.0 --port {PORT} &')
# Ask the Colab front end (google.colab.kernel.proxyPort) for the address of that port.
proxy = eval_js( "google.colab.kernel.proxyPort(" + str(PORT) + ")" )
print(f"tensorboard launched at {proxy}")

In [None]:
# @title 4. Training

#########################
#### Load the config  ###
#########################
import json
import os
def load_config(file_path):
  """Load a JSON configuration file.

  Args:
    file_path: Path of the JSON file to read.

  Returns:
    The parsed JSON value, or None when the file does not exist or its content
    is not valid JSON (an error message is printed in both cases).
  """
  try:
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
      return json.load(f)
  except FileNotFoundError:
    print(f"Error: Configuration file not found at {file_path}")
    return None
  except (json.JSONDecodeError, UnicodeDecodeError):
    # Undecodable bytes are reported the same way as malformed JSON instead of
    # escaping as an unhandled exception.
    print(f"Error: Invalid JSON format in {file_path}")
    return None

# Settings file consumed by this cell; every later step depends on its values.
config = load_config("/content/config.json")
if config is None:
  # Stop here with a clear message. Continuing would only fail later with a
  # NameError on trainingSteps, after the working directory and config copy
  # have already been touched.
  raise RuntimeError("Configuration could not be loaded from /content/config.json; cannot start training.")

# Keys are spelled exactly as they appear in the settings file.
download_count = config["downloadCount"]
notification = config["notification"]
trainingSteps = config["trainingSteps"]
downloadDataForAadditionalTraining = config["downloadDataForAadditionalTraining"]
print(download_count, notification, trainingSteps, downloadDataForAadditionalTraining)


#########################
#### Generate config  ###
#########################
# Run from the trainer checkout: the training command further down invokes
# "python beatrice_trainer" by relative name.
%cd /content/beatrice-trainer

import shutil
import re

# Work on a copy in working_dir so the stock file under assets/ stays untouched.
config_path=os.path.join(working_dir, "default_config.json")
shutil.copy2(os.path.join(working_dir, "assets", "default_config.json"), config_path)
def update_n_steps(step:int, path=None):
  """Set the "n_steps" value in the trainer config file.

  The file is edited as text, so everything except the integer that follows
  "n_steps" is preserved exactly.

  Args:
    step: Number of training steps to write.
    path: Config file to edit. Defaults to the module-level config_path.
  """
  target_path = config_path if path is None else path

  with open(target_path, "r", encoding="utf-8") as f:
      file_content = f.read()

  # Match whatever integer is currently configured, not only the stock 10000,
  # so a changed default in the upstream config cannot silently defeat the edit.
  modified_content, replaced = re.subn(r'("n_steps"\s*:\s*)\d+',
                                       lambda m: f"{m.group(1)}{step}",
                                       file_content)
  if replaced == 0:
      print(f'Warning: "n_steps" not found in {target_path}; training steps were not changed.')
      return

  with open(target_path, "w", encoding="utf-8") as f:
      f.write(modified_content)

update_n_steps(trainingSteps)

#########################
#### resumeの検出     ###
#########################
restored_resume_dir = os.path.join(output_dir, "resume_dir")
if os.path.exists(restored_resume_dir):
  resume=True
  # restored_resume_dirの中身をoutput_dirへコピー
  for f in os.listdir(restored_resume_dir):
    if os.path.isfile(os.path.join(restored_resume_dir, f)):
      source_path = os.path.join(restored_resume_dir, f)
      shutil.copy2(source_path, output_dir) #コピー

else:
  resume=False

#######################
#### トレーニング   ###
#######################
if resume == True:
  !python beatrice_trainer -r -d {data_dir} -o {output_dir} -c {config_path}
else:
  !python beatrice_trainer -d {data_dir} -o {output_dir} -c {config_path}

#############################
#### Download the model   ###
#############################
import zipfile
import tempfile
from google.colab import files
import shutil
import os
def zip_and_download(selected_folder, zip_filename=None):
    """Zip a sub-folder of output_dir and hand the archive to the browser.

    The archive is built in a temporary file, copied to
    /content/<zip_filename>.zip and passed to google.colab's files.download.
    Entries inside the archive are stored relative to output_dir, so they keep
    the selected folder name as their leading path component.

    Args:
        selected_folder: Name of the folder inside output_dir to compress.
        zip_filename: Base name (without ".zip") of the downloaded file.
            Defaults to selected_folder.

    Errors are reported with a printed message rather than raised, so one
    failed download does not abort the rest of the cell.
    """
    zip_file_path = None
    try:
        # Reserve a temporary path for the archive; the file is reopened by
        # zipfile below, so it must survive the end of this with-block.
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as temp_file:
            zip_file_path = temp_file.name

        # Compress the folder.
        folder_path = os.path.join(output_dir, selected_folder)
        print("compress", folder_path)
        with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for root, _, in_files in os.walk(folder_path):
                for file_name in in_files:
                    file_path = os.path.join(root, file_name)
                    arcname = os.path.relpath(file_path, start=output_dir)
                    zipf.write(file_path, arcname)

        if zip_filename is None:
            zip_filename = selected_folder
        destination = f"/content/{zip_filename}.zip"
        # Copy in Python rather than through a shell command, so names with
        # spaces work and a failed copy is reported by the handler below.
        shutil.copyfile(zip_file_path, destination)

        # Download the file.
        files.download(destination)

    except Exception as e:
        print(f"Error during compression and download: {e}")
    finally:
        # The temporary archive was created with delete=False; remove it here.
        if zip_file_path is not None and os.path.exists(zip_file_path):
            os.remove(zip_file_path)

# Result folders in reverse lexicographic order - presumably newest step first.
folders = sorted(
    (
        name
        for name in os.listdir(output_dir)
        if name.startswith("paraphernalia_data")
        and os.path.isdir(os.path.join(output_dir, name))
    ),
    reverse=True,
)
if len(folders) < download_count:
    print(f"Only {len(folders)} model folder(s) available; requested {download_count}.")
# Slice instead of indexing so that asking for more folders than exist cannot
# raise IndexError; max() keeps a negative count equivalent to "none".
for folder in folders[:max(download_count, 0)]:
    zip_and_download(folder)

#####################################
#### Download data for resuming   ###
#####################################
import glob
def download_resume_data():
  """Collect the files needed to resume training later and download them as a zip.

  A fresh output_dir/resume_dir is filled with every regular file at the top
  level of output_dir whose name does not start with "checkpoint_", plus the
  second-newest checkpoint_data_* file saved as checkpoint_latest.pt. The
  directory is then zipped and downloaded as resume_step<step>.zip.

  If fewer than two checkpoint files exist, a message is printed and nothing
  is changed or downloaded.
  """
  # Look in output_dir itself rather than a cwd-relative "output" path, so the
  # result does not depend on the notebook's current directory.
  checkpoint_files = sorted(glob.glob(os.path.join(output_dir, "checkpoint_data_*")))
  if len(checkpoint_files) < 2:
    # Return before touching resume_dir: an empty resume_dir left behind would
    # be detected as a resume request by the next training run.
    print("No checkpoint files found or only one checkpoint file exists.")
    return

  # Pick the checkpoint one before the latest.
  checkpoint = checkpoint_files[-2]
  # Step number = text between the last "_" of the file name and the first "." after it.
  step_num = os.path.basename(checkpoint).split("_")[-1].split(".")[0]

  # Start from an empty resume_dir.
  resume_dir = os.path.join(output_dir, "resume_dir")
  shutil.rmtree(resume_dir, ignore_errors=True)
  os.makedirs(resume_dir, exist_ok=True)

  for f in os.listdir(output_dir):
    source_path = os.path.join(output_dir, f)
    if os.path.isfile(source_path) and not f.startswith("checkpoint_"):
      print(f)
      shutil.copy2(source_path, resume_dir)
  shutil.copy2(checkpoint, f"{resume_dir}/checkpoint_latest.pt")
  zip_and_download("resume_dir", zip_filename=f"resume_step{step_num}")

# Only package resume data when the loaded settings asked for it.
if downloadDataForAadditionalTraining:
  download_resume_data()





In [None]:
# @title A1. レジューム用データのダウンロード。（設定画面で忘れた場合、手動でこのセルを実行してください。）

def download_resume_data():
  """Collect the files needed to resume training later and download them as a zip.

  A fresh output_dir/resume_dir is filled with every regular file at the top
  level of output_dir whose name does not start with "checkpoint_", plus the
  second-newest checkpoint_data_* file saved as checkpoint_latest.pt. The
  directory is then zipped and downloaded as resume_step<step>.zip.

  If fewer than two checkpoint files exist, a message is printed and nothing
  is changed or downloaded.
  """
  # Imported locally so this manually-run cell does not rely on the import
  # statements of the training cell having been executed.
  import glob
  import os
  import shutil

  # Look in output_dir itself rather than a cwd-relative "output" path, so the
  # result does not depend on the notebook's current directory.
  checkpoint_files = sorted(glob.glob(os.path.join(output_dir, "checkpoint_data_*")))
  if len(checkpoint_files) < 2:
    # Return before touching resume_dir: an empty resume_dir left behind would
    # be detected as a resume request by the next training run.
    print("No checkpoint files found or only one checkpoint file exists.")
    return

  # Pick the checkpoint one before the latest.
  checkpoint = checkpoint_files[-2]
  # Step number = text between the last "_" of the file name and the first "." after it.
  step_num = os.path.basename(checkpoint).split("_")[-1].split(".")[0]
  print(step_num)

  # Start from an empty resume_dir.
  resume_dir = os.path.join(output_dir, "resume_dir")
  shutil.rmtree(resume_dir, ignore_errors=True)
  os.makedirs(resume_dir, exist_ok=True)

  for f in os.listdir(output_dir):
    source_path = os.path.join(output_dir, f)
    if os.path.isfile(source_path) and not f.startswith("checkpoint_"):
      print(f)
      shutil.copy2(source_path, resume_dir)
  shutil.copy2(checkpoint, f"{resume_dir}/checkpoint_latest.pt")
  zip_and_download("resume_dir", zip_filename=f"resume_step{step_num}")

download_resume_data()